arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or a subset of it.
- Parameters
- data: InferenceData or InferenceData_like
InferenceData from which to extract the data.
- group: str, optional
Which InferenceData group to extract data from.
- combined: bool, optional
Combine the chain and draw dimensions into sample. Won't work if a dimension named sample already exists.
- var_names: str or list of str, optional
Variables to be extracted. Prefix a variable with ~ to exclude it.
- filter_vars: {None, "like", "regex"}, optional
If None (default), interpret var_names as the real variable names. If "like", interpret var_names as substrings of the real variable names. If "regex", interpret var_names as regular expressions on the real variable names. This mirrors pandas.filter. As with plotting, it is sometimes easier to subset by saying what to exclude rather than what to include.
- num_samples: int, optional
Extract only a subset of the samples. Only valid if combined=True.
- keep_dataset: bool, optional
If true, always return a Dataset. If false (default), return a DataArray when there is a single variable.
- rng: bool, int, numpy.random.Generator, optional
Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. Shuffling ensures that a subset of the samples is not made up only of consecutive draws from a single chain.
- Returns
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)

<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
You can also request a subset, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)> array([[ 1.01866546e+01, 7.25006163e+00, 3.17378321e+00, 9.44361810e+00, 6.83541506e+00, 2.22679229e+00, 2.55728409e+00, 4.44255358e+00, 7.29402025e+00, 1.69134963e+01, 7.52662516e+00, 1.00503087e+00, 7.56338176e+00, 8.68766223e+00, 1.04382282e+01, 7.25006163e+00, 9.05381401e+00, 7.09929855e+00, 7.48425135e+00, 1.69945398e+01, 5.36521047e+00, 4.01471367e+00, 5.35480976e+00, 8.33941061e+00, 3.34359524e+00, -3.72281402e+00, 1.01448164e+01, 3.34359524e+00, -1.45575080e+00, 1.18989120e+01, 1.99712736e+01, 1.21133473e+01, 6.72368373e+00, 6.47907499e+00, 1.31186950e+00, 2.68736677e+00, 1.24760716e+01, 3.51385523e+00, 1.19155656e+01, 4.58205228e+00, 5.17287960e+00, 4.66166341e+00, 6.88803121e+00, 1.59722640e+01, 6.72461144e+00, 1.43710361e+01, -1.62714835e+00, 1.33292138e+01, 2.01472848e+01, 8.66311246e+00, 3.04243417e+00, 1.67363399e+00, 4.53814410e+00, 1.43655881e+01, 1.56183815e+00, 3.98767995e+00, 6.11572170e+00, 4.30549864e+00, 7.35127382e+00, 1.33067636e+01, ... 
1.09491930e+00, 1.40369849e+01, 5.48779218e+00, 6.00954901e+00, 8.04155485e+00, 7.30179078e+00, -3.19486996e-01, 8.71059098e+00, 4.06507329e+00, 3.23104629e+00, 4.47140798e-03, 2.00074480e+00, 6.43142446e+00, 2.00138205e+00, 4.68006899e+00, 3.36798595e+00, 9.39143626e+00, 2.71792506e+00, 1.43019370e+01, 9.21082747e+00, 9.68439852e+00, -3.18008755e+00, 5.04912122e+00, 5.29939272e+00, 5.23985192e+00, 7.44372537e+00, 1.58339992e+01, 8.38931576e+00, 3.65105149e+00, 7.31666261e+00, 2.72514770e+00, -9.83216952e-01, 3.35132225e+00, 7.06699293e+00, -1.48937601e+00, 1.54421419e-01, 9.71725690e+00, 2.56174124e+00, 7.71820976e+00, 1.16705124e+01, 7.09633148e+00, 3.55789284e+01, -1.76985435e+00, 2.80204694e+00, 6.55530380e+00, 3.35132225e+00, 1.52230402e+01, 4.59893074e+00, 3.83319884e+00, -5.93250093e-01, 1.74863467e+01, 1.54844421e+00, 3.74670827e+00, 1.25109351e-01, 2.89576872e+00, -2.25684039e+00, 7.91874973e-01, 4.22582989e+00]]) Coordinates: * school (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' * sample (sample) object MultiIndex * chain (sample) int64 0 3 3 3 3 0 1 1 3 2 1 1 ... 0 0 2 3 3 3 2 2 3 1 3 1 * draw (sample) int64 11 199 280 45 442 79 88 ... 346 113 372 26 166 154
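Because samples are shuffled by default whenever num_samples is given, two calls like the one above will generally return different subsets. The following is a minimal sketch of making the subset reproducible by passing an integer as rng. It assumes the ArviZ 0.x API shown on this page, that the integer acts as a seed, and it uses a synthetic posterior built with az.from_dict (the name idata_demo and the data are hypothetical, not part of the centered_eight example).

```python
import arviz as az
import numpy as np

# Synthetic posterior (hypothetical data): 4 chains, 250 draws each
data_rng = np.random.default_rng(42)
idata_demo = az.from_dict(posterior={"mu": data_rng.normal(size=(4, 250))})

# Passing the same integer as rng should select the same shuffled subset
first = az.extract(idata_demo, var_names="mu", num_samples=50, rng=0)
second = az.extract(idata_demo, var_names="mu", num_samples=50, rng=0)

print(first.sizes["sample"])                         # number of samples kept
print(np.array_equal(first.values, second.values))  # whether both subsets match
```

Leaving rng unset keeps the default behaviour described in the parameter list, which is usually what you want for exploratory work; a fixed seed is mainly useful for tests and reproducible reports.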
To keep the chain and draw dimensions, use combined=False.
az.extract(idata, group="prior", combined=False)

<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2
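The var_names, filter_vars and keep_dataset parameters are described above but not exercised in the examples. The sketch below illustrates them under the assumption of the ArviZ 0.x API documented on this page. It uses a synthetic posterior built with az.from_dict so that it does not depend on the example dataset; idata_demo and its contents are hypothetical.

```python
import arviz as az
import numpy as np

# Synthetic posterior (hypothetical data) with three variables
data_rng = np.random.default_rng(0)
idata_demo = az.from_dict(
    posterior={
        "mu": data_rng.normal(size=(2, 100)),
        "tau": data_rng.normal(size=(2, 100)),
        "theta": data_rng.normal(size=(2, 100, 8)),
    }
)

# Exclude theta with the ~ prefix; the remaining variables are returned
without_theta = az.extract(idata_demo, var_names="~theta")

# Substring match: every variable whose name contains "ta"
like_ta = az.extract(idata_demo, var_names="ta", filter_vars="like")

# A single variable comes back as a DataArray unless keep_dataset=True
mu_array = az.extract(idata_demo, var_names="mu")
mu_dataset = az.extract(idata_demo, var_names="mu", keep_dataset=True)

print(sorted(without_theta.data_vars))
print(sorted(like_ta.data_vars))
print(type(mu_array).__name__, type(mu_dataset).__name__)
```

Setting keep_dataset=True is convenient in code that handles one or many variables uniformly, since the return type no longer depends on how many names matched.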