GaussianMixtureModel

class pyspark.mllib.clustering.GaussianMixtureModel(java_model)[source]

A clustering model derived from the Gaussian Mixture Model method.

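>>> from numpy import array
>>> from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
>>> from numpy.testing import assert_equal
>>> from shutil import rmtree
>>> import os, tempfile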
>>> clusterdata_1 = sc.parallelize(array([-0.1, -0.05, -0.01, -0.1,
...                                       0.9, 0.8, 0.75, 0.935,
...                                       -0.83, -0.68, -0.91, -0.76]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
...                               maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
False
>>> labels[4]==labels[5]
True
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
>>> abs(softPredicted[0] - 1.0) < 0.03
True
>>> abs(softPredicted[1] - 0.0) < 0.03
True
>>> abs(softPredicted[2] - 0.0) < 0.03
True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> try:
...     rmtree(path)
... except OSError:
...     pass
>>> data =  array([-5.1971, -2.5359, -3.8220,
...                -5.2211, -5.0602,  4.7118,
...                 6.8989, 3.4592,  4.6322,
...                 5.7048,  4.6567, 5.5026,
...                 4.5605,  5.2043,  6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
...                               maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]
True
>>> labels[2]==labels[3]==labels[4]
True

New in version 1.3.0.

Methods Documentation

call(name, *a)

Call a method of java_model.

classmethod load(sc, path)[source]

Load the GaussianMixtureModel from disk.

Parameters
  • sc – SparkContext.

  • path – Path where the model is stored.

New in version 1.5.0.

predict(x)[source]

Find the cluster in which the point ‘x’, or each point in the RDD ‘x’, has maximum membership under this model.

Parameters

x – A feature vector or an RDD of vectors representing data points.

Returns

The predicted cluster label, or an RDD of predicted cluster labels if the input is an RDD.

New in version 1.3.0.
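
The "maximum membership" rule can be illustrated without Spark. The following is a minimal sketch, not the library's code: it assumes the soft membership values are already available as plain arrays, and the helper name hard_label is hypothetical. It returns one label for a single membership vector and a list of labels for a batch, mirroring the vector and RDD cases described above.

```python
import numpy as np

def hard_label(memberships):
    """Return the index of the largest membership value.

    Hypothetical helper for illustration: accepts either one membership
    vector (1-D) or a batch of membership vectors (2-D, one row per point).
    """
    m = np.asarray(memberships, dtype=float)
    if m.ndim == 1:
        # Single point: one label.
        return int(np.argmax(m))
    # Batch of points: one label per row.
    return [int(i) for i in np.argmax(m, axis=1)]

# Illustrative membership values only, not output from a trained model.
single = hard_label([0.98, 0.01, 0.01])
batch = hard_label([[0.1, 0.9], [0.7, 0.3]])
```

Ties resolve to the lowest index here because that is how numpy.argmax behaves; this is a property of the sketch, not a statement about the model's tie-breaking.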

predictSoft(x)[source]

Find the membership of the point ‘x’, or of each point in the RDD ‘x’, in every mixture component.

Parameters

x – A feature vector or an RDD of vectors representing data points.

Returns

The membership values across all mixture components for the vector ‘x’, or for each vector in the RDD ‘x’.

New in version 1.3.0.
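
For intuition, soft membership in a Gaussian mixture is conventionally the posterior probability of each component given the point, obtained from the component weights and densities by Bayes' rule. The NumPy sketch below computes that quantity directly. It is an illustration under stated assumptions, not Spark's implementation: the function names gaussian_pdf and soft_memberships and the two-component parameters are hypothetical, and no numerical safeguards for degenerate covariances are included.

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    """Density of a multivariate normal N(mu, sigma) evaluated at x."""
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    d = mu.shape[0]
    diff = x - mu
    # Solve sigma * y = diff rather than forming the inverse explicitly.
    maha = diff @ np.linalg.solve(sigma, diff)
    norm = np.sqrt((2.0 * np.pi) ** d * np.linalg.det(sigma))
    return float(np.exp(-0.5 * maha) / norm)

def soft_memberships(x, weights, mus, sigmas):
    """Posterior probability of each component given x (Bayes' rule)."""
    joint = np.array([w * gaussian_pdf(x, mu, s)
                      for w, mu, s in zip(weights, mus, sigmas)])
    # Normalize so the memberships sum to 1.
    return joint / joint.sum()

# Hypothetical two-component mixture in 2-D (illustrative values only).
weights = [0.5, 0.5]
mus = [[0.0, 0.0], [5.0, 5.0]]
sigmas = [np.eye(2), np.eye(2)]

memberships = soft_memberships([0.1, -0.2], weights, mus, sigmas)
```

The query point lies close to the first mean and far from the second, so nearly all of the membership mass falls on component 0.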

save(sc, path)

Save this model to the given path.

New in version 1.3.0.

Attributes Documentation

gaussians

Array of MultivariateGaussian, where gaussians[i] represents the multivariate Gaussian (normal) distribution for Gaussian i.

New in version 1.4.0.

k

Number of Gaussians in the mixture.

New in version 1.4.0.

weights

Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i and weights.sum() == 1.

New in version 1.4.0.