An arbitrary already-batched computation, a 'primitive operation'.
tfp.experimental.auto_batching.instructions.PrimOp(
vars_in, vars_out, function, skip_push_mask
)
These are the items of work on which auto-batching is applied. The function must accept and produce Tensors with a batch dimension, and is free to stage any (batched) computation it wants.
Restriction: the function must use the same computation substrate as the VM backend. That is, if the VM is staging to XLA, the function will see XLA Tensor handles; if the VM is staging to graph-mode TensorFlow, the function will see TensorFlow Tensors; and so on.
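The following is a minimal pure-Python sketch of what "already batched" means. It is not TFP code: Python lists stand in for Tensors, with the first axis playing the role of the batch dimension, and the names scalar_step and batched_step are hypothetical. The batched function handles every batch member in one call; a real PrimOp function would instead use ops from the VM's backend (for example TensorFlow ops) on whole Tensors.

```python
def scalar_step(x, y):
    """Per-thread computation for a single batch member (hypothetical)."""
    return 2 * x + y


def batched_step(xs, ys):
    """Already-batched version: xs and ys carry a leading batch dimension.

    Lists stand in for Tensors here; a real PrimOp function would use ops
    from the VM's backend on whole Tensors instead of a comprehension.
    """
    return [2 * x + y for x, y in zip(xs, ys)]


xs = [1, 2, 3]
ys = [10, 20, 30]
# One batched call must agree with applying the scalar step member by member.
assert batched_step(xs, ys) == [scalar_step(x, y) for x, y in zip(xs, ys)]
print(batched_step(xs, ys))  # [12, 24, 36]
```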
The current values of the vars_out are saved on their respective stacks, and the results are written to the new top (except for variables listed in skip_push_mask, which are updated in place).
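A toy model can make the push versus in-place distinction concrete. This sketch is an assumption-laden simplification, not the VM's implementation: each variable is a Python list used as a stack, batching and per-thread masking are ignored, vars_out is taken to be a flat list rather than a general pattern, and VarStack and write_outputs are hypothetical names.

```python
class VarStack:
    """Toy model of one VM variable: a stack whose top is the current value."""

    def __init__(self, initial):
        self.items = [initial]

    def top(self):
        return self.items[-1]

    def push(self, value):
        # Save the current value underneath; `value` becomes the new top.
        self.items.append(value)

    def update(self, value):
        # Overwrite the current top in place; nothing is saved.
        self.items[-1] = value


def write_outputs(stacks, vars_out, results, skip_push_mask=frozenset()):
    """Store results into the variables named by a flat vars_out list."""
    for name, value in zip(vars_out, results):
        if name in skip_push_mask:
            stacks[name].update(value)
        else:
            stacks[name].push(value)


stacks = {"a": VarStack(0), "b": VarStack(10)}
write_outputs(stacks, ["a", "b"], [1, 11], skip_push_mask={"b"})
print(stacks["a"].items)  # [0, 1]: old value saved, result on top
print(stacks["b"].items)  # [11]: updated in place
```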
The exact contract for function is as follows:
- It will be invoked with positional arguments only, parallel to vars_in.
- Each argument will be a pattern of Tensors (meaning either one Tensor or a, potentially nested, list or tuple of Tensors), corresponding to the Type of that variable.
- Each Tensor in the argument will have the dtype and shape given in the corresponding TensorType, plus an additional leading batch dimension.
- Some indices in the batch dimension may contain junk data, if the corresponding threads are not executing this instruction [this is subject to change based on the batch execution strategy].
- The function must return a pattern of Tensors, or of objects convertible to Tensors.
- The returned pattern must be compatible with the Types of vars_out.
- The Tensors in the returned pattern must have dtype and shape compatible with the corresponding TensorTypes of vars_out.
- The returned Tensors will be broadcast into their respective positions if necessary. The broadcasting includes the batch dimension: thus, a returned Tensor of insufficient rank (e.g., a constant) will be broadcast across batch members. In particular, a Tensor that carries the intended batch size but whose sub-batch shape is too low rank will broadcast incorrectly and result in an error.
- If the function raises an exception, it will propagate and abort the entire computation.
- Even in the TensorFlow backend, the function will be staged several times: at least twice during type inference (to ascertain the shapes of the Tensors it returns, as a function of the shapes of the Tensors it is given), and exactly once during executable graph construction.
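The broadcasting rule in the list above can be checked purely at the level of shapes. This sketch assumes NumPy-style broadcasting (shapes right-aligned; a dimension that is missing or equal to 1 is stretched), which is the rule TensorFlow follows; broadcast_to_target is a hypothetical helper that only inspects shapes and never touches data.

```python
def broadcast_to_target(returned_shape, target_shape):
    """Check one-directional NumPy-style broadcasting of a returned shape.

    Shapes are right-aligned; each trailing dimension of returned_shape
    must equal the target's or be 1. Missing leading dimensions
    (including the batch dimension) are broadcast.
    """
    if len(returned_shape) > len(target_shape):
        raise ValueError(f"{returned_shape} has higher rank than {target_shape}")
    for r, t in zip(reversed(returned_shape), reversed(target_shape)):
        if r != t and r != 1:
            raise ValueError(f"cannot broadcast {returned_shape} to {target_shape}")
    return tuple(target_shape)


batch_size, event_shape = 4, (3,)
target = (batch_size,) + event_shape  # (4, 3): batch dim plus TensorType shape

print(broadcast_to_target((), target))    # constant: same value for every member
print(broadcast_to_target((3,), target))  # one event, shared across the batch
try:
    # Carries the batch size but lacks the event dim, so 4 is aligned with 3.
    broadcast_to_target((batch_size,), target)
except ValueError as e:
    print("error:", e)
```

The last call fails because right-alignment compares the returned batch size 4 against the event size 3, which matches the "too low rank" case described above.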
Args | |
---|---|
vars_in | List of strings. The names of the VM variables whose current values are passed to function. |
vars_out | Pattern of strings. The names of the VM variables in which to save the results returned by function. |
function | Python callable implementing the computation. |
skip_push_mask | Set of strings, a subset of vars_out. These VM variables will be updated in place rather than pushed. |
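To illustrate how the four arguments fit together, here is a sketch using a hypothetical namedtuple, PrimOpSketch, with the same four fields; it is not the TFP class. The invoke helper (also hypothetical) gathers the current values named by vars_in from a plain dict and passes them positionally, in vars_in order, and check_skip_push_mask enforces the subset requirement from the table, assuming a flat vars_out.

```python
import collections

# Hypothetical stand-in with the same four fields as PrimOp; not the TFP class.
PrimOpSketch = collections.namedtuple(
    "PrimOpSketch", ["vars_in", "vars_out", "function", "skip_push_mask"])


def check_skip_push_mask(op):
    """skip_push_mask must be a subset of vars_out (flat vars_out assumed)."""
    extra = set(op.skip_push_mask) - set(op.vars_out)
    if extra:
        raise ValueError(f"skip_push_mask names not in vars_out: {sorted(extra)}")


def invoke(op, env):
    """Call op.function with current values, positionally, in vars_in order."""
    args = [env[name] for name in op.vars_in]
    return op.function(*args)


op = PrimOpSketch(
    vars_in=["x", "y"],
    vars_out=["z"],
    function=lambda x, y: [a - b for a, b in zip(x, y)],  # batched subtract
    skip_push_mask={"z"},
)
check_skip_push_mask(op)
print(invoke(op, {"x": [5, 7], "y": [1, 2]}))  # [4, 5]
```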
Attributes | |
---|---|
vars_in | |
vars_out | |
function | |
skip_push_mask | |
replace
replace(
vars_out=None
)
Return a copy of self with vars_out replaced.
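A sketch of the replace behavior, written against the same kind of hypothetical PrimOpSketch namedtuple rather than the TFP class. The helper replace_vars_out is hypothetical, and its treatment of None (keeping the existing vars_out) is an assumption, not documented behavior. The original instruction is left unchanged because namedtuple._replace returns a new object.

```python
import collections

# Hypothetical stand-in with the same four fields as PrimOp; not the TFP class.
PrimOpSketch = collections.namedtuple(
    "PrimOpSketch", ["vars_in", "vars_out", "function", "skip_push_mask"])


def replace_vars_out(op, vars_out=None):
    """Return a copy of op with vars_out replaced (hypothetical helper).

    Assumption: passing None keeps the existing vars_out.
    """
    if vars_out is None:
        vars_out = op.vars_out
    return op._replace(vars_out=vars_out)


op = PrimOpSketch(["x"], ["y"], lambda x: x, set())
renamed = replace_vars_out(op, ["y2"])
print(renamed.vars_out, op.vars_out)  # ['y2'] ['y']
```

Since the copy keeps skip_push_mask as is, any names in it should still appear in the new vars_out.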