GaussianMixtureModel
class pyspark.mllib.clustering.GaussianMixtureModel(java_model)
A clustering model derived from the Gaussian Mixture Model method.
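>>> # sc is assumed to be an active SparkContext
>>> from numpy import array
>>> from pyspark.mllib.clustering import GaussianMixture, GaussianMixtureModel
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
>>> from numpy.testing import assert_equal
>>> from shutil import rmtree
>>> import os, tempfile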
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
...                                         0.9,0.8,0.75,0.935,
...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
...                               maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
False
>>> labels[4]==labels[5]
True
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
>>> abs(softPredicted[0] - 1.0) < 0.03
True
>>> abs(softPredicted[1] - 0.0) < 0.03
True
>>> abs(softPredicted[2] - 0.0) < 0.03
True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> try:
...     rmtree(path)
... except OSError:
...     pass
>>> data = array([-5.1971, -2.5359, -3.8220,
...               -5.2211, -5.0602,  4.7118,
...                6.8989,  3.4592,  4.6322,
...                5.7048,  4.6567,  5.5026,
...                4.5605,  5.2043,  6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
...                               maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]
True
>>> labels[2]==labels[3]==labels[4]
True
New in version 1.3.0.
Methods Documentation
call(name, *a)
Call method of java_model.
classmethod load(sc, path)
Load the GaussianMixtureModel from disk.
Parameters:
    sc – SparkContext.
    path – Path to where the model is stored.
New in version 1.5.0.
predict(x)
Find the cluster in which the point ‘x’, or each point in RDD ‘x’, has maximum membership in this model.
Parameters:
    x – A feature vector or an RDD of vectors representing data points.
Returns:
    Predicted cluster label, or an RDD of predicted cluster labels if the input is an RDD.
New in version 1.3.0.
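As a rough illustration of "maximum membership", the hard label for a single point can be read as the index of its largest soft membership value. The helper `hard_label` below is a hypothetical name written in plain Python; it is not part of pyspark.mllib, and it assumes `memberships` is a sequence of per-component values like the one predictSoft returns for one point.

```python
def hard_label(memberships):
    """Return the index of the largest membership value.

    Ties are resolved in favour of the lowest index.
    """
    best = 0
    for i, value in enumerate(memberships):
        if value > memberships[best]:
            best = i
    return best


# Made-up membership values for one point and three components.
example = [0.02, 0.97, 0.01]
label = hard_label(example)
```

With the made-up values above, the second component (index 1) holds the largest membership, so it is the one selected.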
predictSoft(x)
Find the membership of point ‘x’, or of each point in RDD ‘x’, in all mixture components.
Parameters:
    x – A feature vector or an RDD of vectors representing data points.
Returns:
    The membership values in all mixture components for vector ‘x’, or for each vector in RDD ‘x’.
New in version 1.3.0.
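One common way to read these membership values is as posterior probabilities: each component's weight times its density at the point, normalized so the values sum to 1. The sketch below works under that assumption for a one-dimensional mixture in plain Python. The names `gaussian_pdf` and `soft_memberships` are hypothetical, the numbers are made up, and this is not the library's implementation, which operates on multivariate Gaussians.

```python
import math


def gaussian_pdf(x, mean, variance):
    """Density of a one-dimensional normal distribution at x."""
    coeff = 1.0 / math.sqrt(2.0 * math.pi * variance)
    return coeff * math.exp(-((x - mean) ** 2) / (2.0 * variance))


def soft_memberships(x, weights, means, variances):
    """Normalized weight * density for each component (values sum to 1)."""
    scores = [w * gaussian_pdf(x, m, v)
              for w, m, v in zip(weights, means, variances)]
    total = sum(scores)
    return [s / total for s in scores]


# Made-up two-component mixture, for illustration only.
memberships = soft_memberships(0.1, [0.5, 0.5], [0.0, 5.0], [1.0, 1.0])
```

A point close to one component's mean and far from the other's receives almost all of its membership from the nearby component, which matches the pattern in the softPredicted checks of the class example.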
save(sc, path)
Save this model to the given path.
New in version 1.3.0.
Attributes Documentation
gaussians
Array of MultivariateGaussian, where gaussians[i] represents the multivariate Gaussian (normal) distribution for Gaussian i.
New in version 1.4.0.
k
Number of Gaussians in the mixture.
New in version 1.4.0.
weights
Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i and the weights sum to 1.
New in version 1.4.0.
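Because the weights are floating-point values, a check that they sum to 1 is usually better done with a tolerance than with exact equality. The sketch below shows one such check in plain Python. The function `check_mixture` and the tolerance `1e-9` are assumptions chosen for illustration and are not part of pyspark.mllib.

```python
import math


def check_mixture(k, weights, num_gaussians):
    """Return True if the sizes agree and the weights sum to 1 within tolerance."""
    sizes_agree = (k == len(weights) == num_gaussians)
    sums_to_one = math.isclose(sum(weights), 1.0, rel_tol=0.0, abs_tol=1e-9)
    return sizes_agree and sums_to_one


# Made-up weights for a three-component mixture.
weights = [0.1, 0.2, 0.7]
ok = check_mixture(3, weights, 3)
```

For a trained model, the same idea would apply to the values of k, weights, and the length of gaussians.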