flax.linen.get_sharding

flax.linen.get_sharding(tree, mesh)

Extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh.

Parameters:
tree (Any) – a PyTree whose leaves may be Partitioned values.
mesh (Mesh) – the jax.sharding.Mesh that the resulting shardings refer to.

Return type: Any