StringIndexer

StringIndexer encodes a string column of labels into a column of label indices. The indices are in [0, numLabels). Four ordering options are supported (default = "frequencyDesc"):

- "frequencyDesc": descending order by label frequency (the most frequent label is assigned 0)
- "frequencyAsc": ascending order by label frequency (the least frequent label is assigned 0)
- "alphabetDesc": descending alphabetical order
- "alphabetAsc": ascending alphabetical order

Unseen labels are put at index numLabels if the user chooses to keep them. If the input column is numeric, it is cast to string and the string values are indexed. When downstream pipeline components such as an Estimator or Transformer make use of this string-indexed label, you must set the input column of that component to the string-indexed column name. In many cases, you can set the input column with setInputCol.
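To make the four ordering options concrete, here is a minimal plain-Scala sketch of how each one could turn a list of labels into indices. It does not use Spark: `OrderingSketch` and `indexLabels` are hypothetical names introduced for illustration, and the alphabetical tie-break for equal frequencies is an assumption of the sketch. In Spark itself the option is chosen on the indexer with `setStringOrderType`.

```scala
// Sketch only: mimics how each ordering option could assign indices.
// `OrderingSketch` and `indexLabels` are hypothetical, not part of the Spark API.
object OrderingSketch {
  def indexLabels(values: Seq[String], orderType: String): Map[String, Double] = {
    // Count how often each distinct label occurs.
    val counts: Seq[(String, Int)] =
      values.groupBy(identity).map { case (label, occ) => (label, occ.size) }.toSeq

    val ordered: Seq[String] = orderType match {
      // Assumption: labels with equal frequency are ordered alphabetically.
      case "frequencyDesc" => counts.sortBy { case (l, c) => (-c, l) }.map(_._1)
      case "frequencyAsc"  => counts.sortBy { case (l, c) => (c, l) }.map(_._1)
      case "alphabetDesc"  => counts.map(_._1).sorted.reverse
      case "alphabetAsc"   => counts.map(_._1).sorted
      case other =>
        throw new IllegalArgumentException(s"Unknown order type: $other")
    }

    // The position in the ordered list becomes the label index.
    ordered.zipWithIndex.map { case (label, i) => (label, i.toDouble) }.toMap
  }
}
```

Calling `OrderingSketch.indexLabels(Seq("a", "b", "c", "a", "a", "c"), "frequencyDesc")` maps "a" to 0.0, "c" to 1.0, and "b" to 2.0, because "a" occurs three times, "c" twice, and "b" once. With "alphabetAsc" the same labels map to 0.0, 1.0, and 2.0 in the order "a", "b", "c".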

Examples

Assume that we have the following DataFrame with columns id and category:

 id | category
----|----------
 0  | a
 1  | b
 2  | c
 3  | a
 4  | a
 5  | c

category is a string column with three labels: "a", "b", and "c". Applying StringIndexer with category as the input column and categoryIndex as the output column, we should get the following:

 id | category | categoryIndex
----|----------|---------------
 0  | a        | 0.0
 1  | b        | 2.0
 2  | c        | 1.0
 3  | a        | 0.0
 4  | a        | 0.0
 5  | c        | 1.0

"a" gets index 0 because it is the most frequent, followed by "c" with index 1 and "b" with index 2.

Additionally, there are three strategies regarding how StringIndexer will handle unseen labels when you have fit a StringIndexer on one dataset and then use it to transform another:

- throw an exception (which is the default)
- skip the row containing the unseen label entirely
- put unseen labels in a special additional bucket, at index numLabels
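The three strategies can be sketched in plain Scala on top of an already fitted label-to-index map. This is an illustration rather than Spark's implementation: `UnseenLabelSketch` and its `transform` method are hypothetical names, and the exception type is a stand-in chosen for the sketch. In Spark the strategy is selected on the indexer with `setHandleInvalid`, whose values "error", "skip", and "keep" correspond to the three cases below.

```scala
// Sketch only: models the three unseen-label strategies on a plain Map.
// `UnseenLabelSketch` and `transform` are hypothetical, not part of the Spark API.
object UnseenLabelSketch {
  def transform(
      fitted: Map[String, Double],
      rows: Seq[String],
      handleInvalid: String): Seq[(String, Double)] = {
    val numLabels = fitted.size
    handleInvalid match {
      // "error": fail as soon as a label is missing from the fitted mapping.
      case "error" =>
        rows.map { label =>
          val idx = fitted.getOrElse(
            label,
            throw new NoSuchElementException(s"Unseen label: $label"))
          (label, idx)
        }
      // "skip": drop rows whose label was not seen during fitting.
      case "skip" =>
        rows.flatMap(label => fitted.get(label).map(idx => (label, idx)))
      // "keep": send every unseen label to one extra bucket at index numLabels.
      case "keep" =>
        rows.map(label => (label, fitted.getOrElse(label, numLabels.toDouble)))
      case other =>
        throw new IllegalArgumentException(s"Unknown strategy: $other")
    }
  }
}
```

With a fitted mapping of "a" to 0.0, "c" to 1.0, and "b" to 2.0, transforming the rows "a" and "d" under "keep" assigns "d" the index 3.0, under "skip" returns only the row for "a", and under "error" throws.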

import org.apache.spark.ml.feature.StringIndexer

val df = spark.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

val indexed = indexer.fit(df).transform(df)
indexed.show()

/*
output:
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+
*/
