API Documentation#
This documentation covers PyTensor module by module. It is suited to finding the Types and Ops that you can use to build and compile expression graphs.
Modules#
compile – Transforming Expression Graphs to Functions
config – PyTensor Configuration
d3viz – d3viz: Interactive visualization of PyTensor compute graphs
graph – PyTensor Graph Internals
printing – Graph Printing and Symbolic Print Statement
scan – Looping in PyTensor
sparse – Symbolic Sparse Matrices
sparse – Sparse Op
tensor – Tensor operations in PyTensor
typed_list – Typed List
xtensor – XTensor operations
There are also some top-level imports that you might find more convenient:
Graph#
- pytensor.shared(...)[source]#
Alias for pytensor.compile.sharedvalue.shared()
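To illustrate (a minimal sketch, not part of the original docstring; names are illustrative), a shared variable holds state that persists across compiled function calls:

import numpy as np

import pytensor
import pytensor.tensor as pt

# Shared storage that persists across calls to compiled functions.
state = pytensor.shared(np.float64(0.0), name="state")
inc = pt.scalar("inc")

# `updates` tells the compiled function how to modify `state` after each call.
accumulate = pytensor.function([inc], state, updates=[(state, state + inc)])
accumulate(1.0)  # returns the old value 0.0, then sets state to 1.0
print(state.get_value())  # 1.0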
- pytensor.function(...)[source]#
Alias for pytensor.compile.function.function()
- pytensor.clone_replace(...)[source]#
Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding substitutions.
- Parameters:
output – PyTensor expression that represents the computational graph.
replace – Dictionary describing which subgraphs should be replaced by what.
rebuild_kwds – Keywords to rebuild_collect_shared.
Alias for pytensor.graph.basic.clone_replace()
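A minimal sketch of typical usage (the variable names are illustrative, not from the docstring):

import pytensor.tensor as pt
from pytensor import clone_replace

x = pt.vector("x")
y = pt.vector("y")
out = (x + y) ** 2

# Build a copy of `out` in which the subgraph `x` is replaced by `2 * y`.
new_out = clone_replace(out, replace={x: 2 * y})
# `new_out` computes (2 * y + y) ** 2; the original `out` is unchanged.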
Control flow#
- pytensor.scan(...)[source]#
This function constructs and applies a ScanOp to the provided arguments.
- Parameters:
fn –
fn is a function that describes the operations involved in one step of scan. fn should construct variables describing the output of one iteration step. It should expect as input Variables representing all the slices of the input sequences and previous values of the outputs, as well as all other arguments given to scan as non_sequences. The order in which scan passes these variables to fn is the following:
- all time slices of the first sequence
- all time slices of the second sequence
- …
- all time slices of the last sequence
- all past slices of the first output
- all past slices of the second output
- …
- all past slices of the last output
- all other arguments (the list given as non_sequences to scan)
The order of the sequences is the same as the one in the list sequences given to scan. The order of the outputs is the same as the order of outputs_info. For any sequence or output, the order of the time slices is the same as the one in which they have been given as taps. For example, if one writes the following:

scan(
    fn,
    sequences=[
        dict(input=Sequence1, taps=[-3, 2, -1]),
        Sequence2,
        dict(input=Sequence3, taps=3),
    ],
    outputs_info=[
        dict(initial=Output1, taps=[-3, -5]),
        dict(initial=Output2, taps=None),
        Output3,
    ],
    non_sequences=[Argument1, Argument2],
)

then fn should expect the following arguments in this given order:
- sequence1[t-3]
- sequence1[t+2]
- sequence1[t-1]
- sequence2[t]
- sequence3[t+3]
- output1[t-3]
- output1[t-5]
- output3[t-1]
- argument1
- argument2
The list of non_sequences can also contain shared variables used in the function, though scan is able to figure those out on its own, so they can be skipped. For the clarity of the code, we nevertheless recommend providing them to scan. To some extent, scan can also figure out other non-sequences (not shared), even if they are not passed to scan (but are used by fn). A simple example of this would be:

import pytensor.tensor as pt

W = pt.matrix()
W_2 = W**2


def f(x):
    return pt.dot(x, W_2)
The function fn is expected to return two things. One is a list of outputs ordered in the same order as outputs_info, with the difference that there should be only one output variable per output initial state (even if no tap value is used). Second, fn should return an update dictionary that tells how to update any shared variable after each iteration step. The dictionary can optionally be given as a list of tuples. There is no constraint on the order of these two; fn can return either (outputs_list, update_dictionary), (update_dictionary, outputs_list), or just one of the two (in case the other is empty).

To use scan as a while loop, the user needs to change the function fn such that a stopping condition is also returned. To do so, one needs to wrap the condition in an until class. The condition should be returned as a third element (a complete runnable sketch appears after this scan entry), for example:

...
return [y1_t, y2_t], {x: x + 1}, until(x < 50)

Note that a number of steps, treated here as the maximum number of steps, is still required even though a condition is passed. It is used to allocate memory if needed.
- sequences
sequences is the list of Variables or dicts describing the sequences scan has to iterate over. If a sequence is given as wrapped in a dict, then a set of optional information can be provided about the sequence. The dict should have the following keys:
- input (mandatory) – Variable representing the sequence.
- taps – Temporal taps of the sequence required by fn. They are provided as a list of integers, where a value k implies that at iteration step t, scan will pass to fn the slice t+k. The default value is [0].
All Variables in the list sequences are automatically wrapped into a dict where taps is set to [0].
- outputs_info
outputs_info is the list of Variables or dicts describing the initial state of the outputs computed recurrently. When the initial states are given as dicts, optional information can be provided about the output corresponding to those initial states. The dict should have the following keys:
- initial – A Variable that represents the initial state of a given output. In case the output is not computed recursively (e.g. a map-like function) and does not require an initial state, this field can be skipped. Given that only the previous time step of the output is used by fn, the initial state should have the same shape as the output and should not involve a downcast of the data type of the output. If multiple time taps are used, the initial state should have one extra dimension that covers all the possible taps. For example, if we use -5, -2 and -1 as past taps, at step 0 fn will require (by an abuse of notation) output[-5], output[-2] and output[-1]. This will be given by the initial state, which in this case should have the shape (5,) + output.shape. If this Variable containing the initial state is called init_y, then init_y[0] corresponds to output[-5], init_y[1] to output[-4], init_y[2] to output[-3], init_y[3] to output[-2], and init_y[4] to output[-1]. While this order might seem strange, it comes naturally from splitting an array at a given point: assume that we have an array x, and we choose k to be time step 0. Then our initial state would be x[:k], while the output will be x[k:]. Looking at this split, elements in x[:k] are ordered exactly like those in init_y.
- taps – Temporal taps of the output that will be passed to fn. They are provided as a list of negative integers, where a value k implies that at iteration step t, scan will pass to fn the slice t+k.
scan will follow this logic if partial information is given:
- If an output is not wrapped in a dict, scan will wrap it in one, assuming that you use only the last step of the output (i.e. it makes your tap value list equal to [-1]).
- If you wrap an output in a dict and you do not provide any taps but you do provide an initial state, it will assume that you are using only a tap value of -1.
- If you wrap an output in a dict but you do not provide any initial state, it assumes that you are not using any form of taps.
- If you provide None instead of a Variable or an empty dict, scan assumes that you will not use any taps for this output (as in the case of a map, for example).
If outputs_info is an empty list or None, scan assumes that no tap is used for any of the outputs. If information is provided just for a subset of the outputs, an exception is raised, because there is no convention on how scan should map the provided information to the outputs of fn.
- non_sequences
non_sequences is the list of arguments that are passed to fn at each step. One can choose to exclude variables used in fn from this list, as long as they are part of the computational graph, although, for clarity, this is not encouraged.
- n_steps
n_steps is the number of steps to iterate, given as an int or a scalar Variable. If any of the input sequences do not have enough elements, scan will raise an error. If the value is 0, the outputs will have 0 rows. If n_steps is not provided, scan will figure out the number of steps it should run given its input sequences. n_steps < 0 is not supported anymore.
- truncate_gradient
truncate_gradient is the number of steps to use in truncated back-propagation through time (BPTT). If you compute gradients through a ScanOp, they are computed using BPTT. By providing a value other than -1, you choose to use truncated BPTT instead of classical BPTT, going back only truncate_gradient steps in time.
- go_backwards
go_backwards is a flag indicating whether scan should go backwards through the sequences. If you think of each sequence as indexed by time, making this flag True would mean that scan goes back in time, namely that for any sequence it starts from the end and goes towards 0.
- name
When profiling scan, it is helpful to provide a name for any instance of scan. For example, the profiler will produce an overall profile of your code as well as profiles for the computation of one step of each instance of Scan. The name of the instance appears in those profiles and can greatly help to disambiguate information.
- mode
The mode used to compile the inner graph. If you prefer the computations of one step of scan to be done differently than the entire function, you can use this parameter to describe how the computations in this loop are done (see pytensor.function for details about possible values and their meaning).
- profile
If True or a non-empty string, a profile object will be created and attached to the inner graph of Scan. When profile is True, the profiler results will use the name of the Scan instance; otherwise it will use the passed string. The profiler only collects and prints information when running the inner graph with the CVM Linker.
- allow_gc
Set the value of allow_gc for the internal graph of the Scan. If set to None, this will use the value of pytensor.config.scan__allow_gc.

The full Scan behavior related to allocation is determined by this value and the flag pytensor.config.allow_gc. If the flag allow_gc is True (default) and this allow_gc is False (default), then we let Scan allocate all intermediate memory on the first iteration, and it is not garbage collected after that first iteration; this is determined by allow_gc. This can speed up allocation in subsequent iterations. All those temporary allocations are freed at the end of all iterations; this is what the flag pytensor.config.allow_gc means.
- strict
If True, all the shared variables used in fn must be provided as a part of non_sequences or sequences.
- return_list
If True, will always return a list, even if there is only one output.
- Returns:
A tuple of the form (outputs, updates). outputs is either a Variable or a list of Variables representing the outputs in the same order as in outputs_info. updates is a subclass of dict specifying the update rules for all shared variables used in Scan. This dict should be passed to pytensor.function when you compile your function.
- Return type:
tuple
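As a usage illustration, here is a minimal sketch (the names are illustrative) of a scan loop that computes A**k elementwise by repeated multiplication:

import pytensor
import pytensor.tensor as pt

k = pt.iscalar("k")
A = pt.vector("A")

# One step of the loop: multiply the previous result by A.
result, updates = pytensor.scan(
    fn=lambda prior_result, A: prior_result * A,
    outputs_info=pt.ones_like(A),  # initial state: a vector of ones
    non_sequences=A,
    n_steps=k,
)

# `result` stacks the output of every step; keep only the last one.
power = pytensor.function(inputs=[A, k], outputs=result[-1], updates=updates)
print(power(list(range(10)), 2))  # [ 0.  1.  4. ... 81.]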
Alias for pytensor.scan.basic.scan()
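The while-loop form referenced under fn above can be sketched as follows, assuming until is importable from pytensor.scan.utils:

import pytensor
import pytensor.tensor as pt
from pytensor.scan.utils import until  # assumed import path for `until`


def step(prev):
    nxt = prev * 2
    # The extra `until` element tells scan to stop once the value exceeds 100.
    return nxt, until(nxt > 100)


x0 = pt.scalar("x0")
# `n_steps` still bounds the loop and is used to allocate memory if needed.
values, updates = pytensor.scan(step, outputs_info=x0, n_steps=1024)
f = pytensor.function([x0], values[-1], updates=updates)
print(f(1.0))  # 128.0: doubling from 1 stops after exceeding 100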
Convert to Variable#
- pytensor.as_symbolic(...)[source]#
Convert x into an equivalent PyTensor Variable.
- Parameters:
x – The object to be converted into a Variable type. A numpy.ndarray argument will not be copied, but a list of numbers will be copied to make a numpy.ndarray.
name – If a new Variable instance is created, it will be named with this string.
kwargs – Options passed to the appropriate sub-dispatch functions. For example, ndim and dtype can be passed when x is a numpy.ndarray or Number type.
- Raises:
TypeError – If x cannot be converted to a Variable.
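A minimal sketch of the conversion behavior described above:

import numpy as np

import pytensor

# An ndarray is wrapped without copying; a list is first copied into an ndarray.
x = pytensor.as_symbolic(np.arange(3))
y = pytensor.as_symbolic([1.0, 2.0, 3.0], name="y")
print(type(x).__name__, x.type)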
Wrap JAX functions#
- pytensor.wrap_jax(...)[source]#
Return a PyTensor-compatible function from a JAX jittable function.
This decorator wraps a JAX function so that it accepts and returns pytensor.Variable objects. The JAX-jittable function can accept any nested Python structure (a pytree) as input, and might return any nested Python structure.
- Parameters:
jax_function (Callable, optional) – A JAX function to be wrapped. If None, returns a decorator function.
allow_eval (bool, default=True) – Whether to allow evaluation of symbolic shapes when input shapes are not fully determined.
- Returns:
A function that wraps the given JAX function so that it can be called with pytensor.Variable inputs and returns pytensor.Variable outputs.
- Return type:
Callable
Examples
>>> import jax.numpy as jnp
>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor import wrap_jax
>>> @wrap_jax
... def add(x, y):
...     return jnp.add(x, y)
>>> x = pt.scalar("x")
>>> y = pt.scalar("y")
>>> result = add(x, y)
>>> f = pytensor.function([x, y], [result])
>>> print(f(1, 2))
[array(3.)]
We can also pass arbitrary JAX pytree structures as inputs and outputs:
>>> import jax
>>> import jax.numpy as jnp
>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor import wrap_jax
>>> @wrap_jax
... def complex_function(x, y, scale=1.0):
...     return {
...         "sum": jnp.add(x, y) * scale,
...     }
>>> x = pt.vector("x", shape=(3,))
>>> y = pt.vector("y", shape=(3,))
>>> result = complex_function(x, y, scale=2.0)
>>> f = pytensor.function([x, y], [result["sum"]])
Or Equinox modules:
>>> x = pt.tensor("x", shape=(3,))  # doctest: +SKIP
>>> y = pt.tensor("y", shape=(3,))  # doctest: +SKIP
>>> import equinox as eqx  # doctest: +SKIP
>>> mlp = eqx.nn.MLP(
...     3, 3, 3, depth=2, activation=jnp.tanh, key=jax.random.key(0)
... )  # doctest: +SKIP
>>> mlp = eqx.tree_at(lambda m: m.layers[0].bias, mlp, y)  # doctest: +SKIP
>>> @wrap_jax  # doctest: +SKIP
... def neural_network(x, mlp):
...     return mlp(x)
>>> out = neural_network(x, mlp)  # doctest: +SKIP
If the input shapes are not fully determined, and valid input shapes cannot be inferred by evaluating the inputs either, an error will be raised:
>>> import jax.numpy as jnp
>>> import pytensor.tensor as pt
>>> @wrap_jax
... def add(x, y):
...     return jnp.add(x, y)
>>> x = pt.vector("x")  # shape is not fully determined
>>> y = pt.vector("y")  # shape is not fully determined
>>> result = add(x, y)
ValueError: Could not compile a function to infer example shapes. Please provide inputs with fully determined shapes by calling pt.specify_shape.
...
Alias for pytensor.link.jax.ops.wrap_jax()
Debug#
- pytensor.dprint(...)[source]#
Print a graph as text.
Each line printed represents a Variable in the graph. The indentation of lines corresponds to its depth in the symbolic graph. The first part of the text identifies whether it is an input or the output of some Apply node. The second part of the text is an identifier of the Variable.

If a Variable is encountered multiple times in the depth-first search, it is only printed recursively the first time. Later, just the Variable identifier is printed.

If an Apply node has multiple outputs, then a .N suffix will be appended to the Apply node's identifier, indicating to which output a line corresponds.
- Parameters:
graph_like – The object(s) to be printed.
depth – Print graph to this depth (-1 for unlimited).
print_type – If True, print the Types of each Variable in the graph.
print_shape – If True, print the shape of each Variable in the graph.
file – When file extends TextIO, print to it; when file is equal to "str", return a string; when file is None, print to sys.stdout.
id_type – Determines the type of identifier used for Variables:
- "id": print the Python id value,
- "int": print an integer character,
- "CHAR": print a capital character,
- "auto": print the Variable.auto_name values,
- "": don't print an identifier.
stop_on_name – When True, if a node in the graph has a name, we don't print anything below it.
done – A dict where we store the ids of printed nodes. Useful to have multiple calls to debugprint share the same ids.
print_storage – If True, this will print the storage map for PyTensor functions. When combined with allow_gc=False, after the execution of a PyTensor function, the output will show the intermediate results.
used_ids – A map between nodes and their printed ids.
print_op_info – Print extra information provided by the relevant Ops. For example, print the tap information for Scan inputs and outputs.
print_destroy_map – Whether to print the destroy_maps of printed objects.
print_view_map – Whether to print the view_maps of printed objects.
print_memory_map – Whether to set both print_destroy_map and print_view_map to True.
print_fgraph_inputs – Print the inputs of FunctionGraphs.
- Return type:
A string representing the printed graph, if file is a string, else file.
Alias for pytensor.printing.debugprint()
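A minimal sketch of typical usage (the variable names are illustrative):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
out = (x + y) * 2

# Print the graph of `out` to stdout, including each Variable's Type.
pytensor.dprint(out, print_type=True)

# Pass file="str" to get the text back instead of printing it.
text = pytensor.dprint(out, file="str")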