J'utilise RandomForest.featureImportances
mais je ne comprends pas le résultat de sortie.
J'ai 12 fonctionnalités, et c'est la sortie que j'obtiens.
Je comprends que ce n'est peut-être pas une question spécifique à Apache-Spark, mais je ne trouve nulle part qui explique la sortie.
// org.apache.spark.mllib.linalg.Vector = (12,[0,1,2,3,4,5,6,7,8,9,10,11],
[0.1956128039688559,0.06863606797951556,0.11302128590305296,0.091986700351889,0.03430651625283274,0.05975817050022879,0.06929766152519388,0.052654922125615934,0.06437052114945474,0.1601713590349946,0.0324327322375338,0.057751258970832206])
2 réponses
Étant donné un modèle d'ensemble d'arbres, RandomForest.featureImportances
calcule l'importance de chaque caractéristique.
Cela généralise l'idée de l'importance de "Gini" à d'autres pertes, suite à l'explication de l'importance de Gini de la documentation "Random Forests" par Leo Breiman et Adele Cutler, et suite à la mise en œuvre de scikit-learn.
Pour les collections d'arbres, ce qui comprend le renforcement et l'ensachage, Hastie et al. suggère d'utiliser la moyenne de l'importance d'un seul arbre sur tous les arbres de l'ensemble.
Et l'importance de cette caractéristique est calculée comme suit :
- Moyenne sur les arbres :
- importance (caractéristique j) = somme (sur les nœuds qui se séparent sur la caractéristique j) du gain, où le gain est mis à l'échelle par le nombre d'instances passant par le nœud
- Normaliser les importances de l'arbre à additionner à 1.
- Normaliser le vecteur d'importance des caractéristiques pour qu'il soit égal à 1.
Références : Hastie, Tibshirani, Friedman . "Les éléments de l'apprentissage statistique, 2e édition." 2001. - 15.3.2 Importance variable page 593.
Revenons à votre vecteur d'importance :
val importanceVector = Vectors.sparse(12,Array(0,1,2,3,4,5,6,7,8,9,10,11), Array(0.1956128039688559,0.06863606797951556,0.11302128590305296,0.091986700351889,0.03430651625283274,0.05975817050022879,0.06929766152519388,0.052654922125615934,0.06437052114945474,0.1601713590349946,0.0324327322375338,0.057751258970832206))
Tout d'abord, trions ces fonctionnalités par importance :
importanceVector.toArray.zipWithIndex
.map(_.swap)
.sortBy(-_._2)
.foreach(x => println(x._1 + " -> " + x._2))
// 0 -> 0.1956128039688559
// 9 -> 0.1601713590349946
// 2 -> 0.11302128590305296
// 3 -> 0.091986700351889
// 6 -> 0.06929766152519388
// 1 -> 0.06863606797951556
// 8 -> 0.06437052114945474
// 5 -> 0.05975817050022879
// 11 -> 0.057751258970832206
// 7 -> 0.052654922125615934
// 4 -> 0.03430651625283274
// 10 -> 0.0324327322375338
Alors qu'est-ce que cela signifie ?
Cela signifie que votre première caractéristique (index 0) est la caractéristique la plus importante avec un poids de ~ 0,19 et votre 11e caractéristique (index 10) est la moins importante dans votre modèle.
En complément de la réponse précédente :
L'un des problèmes auxquels j'ai été confronté était de vider le résultat sous la forme de (featureName,Importance) en tant que csv.On peut obtenir les métadonnées pour le vecteur d'entrée des fonctionnalités comme
val featureMetadata = predictions.schema("features").metadata
Voici la structure json de ces métadonnées :
{
"ml_attr": {
"attrs":
{"numeric":[{idx:I,name:N},...],
"nominal":[{vals:V,idx:I,name:N},...]},
"num_attrs":#Attr
}
}
}
Code pour extraire l'importance :
val attrs =featureMetadata.getMetadata("ml_attr").getMetadata("attrs")
val f: (Metadata) => (Long,String) = (m => (m.getLong("idx"), m.getString("name")))
val nominalFeatures= attrs.getMetadataArray("nominal").map(f)
val numericFeatures = attrs.getMetadataArray("numeric").map(f)
val features = (numericFeatures ++ nominalFeatures).sortBy(_._1)
val fImportance = pipeline.stages.filter(_.uid.startsWith("rfc")).head.asInstanceOf[RandomForestClassificationModel].featureImportances.toArray.zip(features).map(x=>(x._2._2,x._1)).sortBy(-_._2)
//Save It now
sc.parallelize(fImportance.toSeq, 1).map(x => s"${x._1},${x._2}").saveAsTextFile(fPath)
Questions connexes
Questions liées
De nouvelles questions
apache-spark
Apache Spark est un moteur de traitement de données distribué open source écrit en Scala fournissant une API unifiée et des ensembles de données distribués aux utilisateurs pour le traitement par lots et en continu. Les cas d'utilisation d'Apache Spark sont souvent liés à l'apprentissage automatique / profond, au traitement des graphiques.