Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Included StringIndexer and StringIndexerModel along with related test... #804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
ramanathanv wants to merge 16 commits into dotnet:main
base: main
Choose a base branch
Loading
from ramanathanv:StringIndexer
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
1f4bcdb
Included StringIndexer and StringIndexerModel along with related test...
Jan 6, 2021
65f43d0
Corrected issue in test case
Jan 6, 2021
f3b287c
Corrected issue in test case
Jan 6, 2021
89d1b98
Merge branch 'master' into StringIndexer
ramanathanv Jan 18, 2021
0d04752
Merge branch 'master' into StringIndexer
imback82 Jan 22, 2021
ac68589
Merge branch 'master' into StringIndexer
ramanathanv Jan 29, 2021
f78d4ce
Corrected the test case
Jan 29, 2021
6cd1a7c
Changed FirstorDefault to Where
Feb 1, 2021
6ead393
Modified List datatype
Feb 1, 2021
fa1add4
Corrected the internal property names
Feb 3, 2021
24b3331
Merge branch 'master' into StringIndexer
ramanathanv Feb 3, 2021
643789c
Changed List comparison
Feb 3, 2021
a5007c9
Merge branch 'StringIndexer' of https://github.com/ramanathanv/spark ...
Feb 3, 2021
4cc337f
Reverted direct List Check
Feb 3, 2021
59ea4e5
Merge branch 'master' into StringIndexer
ramanathanv Feb 4, 2021
641e1d3
Merge branch 'master' into StringIndexer
ramanathanv Feb 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
View file Open in desktop
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
[Collection("Spark E2E Tests")]
public class StringIndexerModelTests : FeatureBaseTests<StringIndexerModel>
{
private readonly SparkSession _spark;

public StringIndexerModelTests(SparkFixture fixture) : base(fixture)
{
_spark = fixture.Spark;
}

/// <summary>
/// Create a <see cref="DataFrame"/>, create a <see cref="StringIndexerModel"/> and test the
/// available methods.
/// </summary>
[Fact]
public void TestStringIndexerModel()
{
DataFrame input = _spark.CreateDataFrame(
new List<GenericRow>
{
new GenericRow(new object[] {0, "a"}),
new GenericRow(new object[] {1, "b"}),
new GenericRow(new object[] {2, "c"}),
new GenericRow(new object[] {3, "a"}),
new GenericRow(new object[] {4, "a"}),
new GenericRow(new object[] {5, "c"})
},
new StructType(new List<StructField>
{
new StructField("id", new IntegerType()),
new StructField("category", new StringType())
}));

string expectedUid = "theUid";
StringIndexer stringIndexer = new StringIndexer(expectedUid)
.SetInputCol("category")
.SetOutputCol("categoryIndex");

StringIndexerModel stringIndexerModel = stringIndexer.Fit(input);
DataFrame transformedDF = stringIndexerModel.Transform(input);
List<Row> observed = transformedDF.Select("category", new string[] { "categoryIndex" })
.Collect().ToList();
List<Row> expected = new List<Row>
{
new Row(new GenericRow(new object[] {"a", "0"})),
new Row(new GenericRow(new object[] {"b", "2"})),
new Row(new GenericRow(new object[] {"c", "1"})),
new Row(new GenericRow(new object[] {"a", "0"})),
new Row(new GenericRow(new object[] {"a", "0"})),
new Row(new GenericRow(new object[] {"c", "1"}))
};

observed.ForEach(a =>
{
Assert.Equal(a, expected.Where(b => b == a).FirstOrDefault());
}
);
Assert.Equal("category", stringIndexer.GetInputCol());
Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
Assert.Equal(expectedUid, stringIndexer.Uid());

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "stringIndexerModel");
stringIndexerModel.Save(savePath);

StringIndexerModel loadedModel = StringIndexerModel.Load(savePath);
Assert.Equal(stringIndexerModel.Uid(), loadedModel.Uid());
}
}
}
}
View file Open in desktop
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
[Collection("Spark E2E Tests")]
public class StringIndexerTests : FeatureBaseTests<StringIndexer>
{
private readonly SparkSession _spark;

public StringIndexerTests(SparkFixture fixture) : base(fixture)
{
_spark = fixture.Spark;
}

/// <summary>
/// Create a <see cref="DataFrame"/>, create a <see cref="StringIndexer"/> and test the
/// available methods.
/// </summary>
[Fact]
public void TestStringIndexer()
{
string expectedUid = "theUid";
StringIndexer stringIndexer = new StringIndexer(expectedUid)
.SetInputCol("category")
.SetOutputCol("categoryIndex");

Assert.Equal("category", stringIndexer.GetInputCol());
Assert.Equal("categoryIndex", stringIndexer.GetOutputCol());
Assert.Equal(expectedUid, stringIndexer.Uid());

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "stringIndexer");
stringIndexer.Save(savePath);

StringIndexer loadedstringIndexer = StringIndexer.Load(savePath);
Assert.Equal(stringIndexer.Uid(), loadedstringIndexer.Uid());
}
}
}
}
174 changes: 174 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Feature/StringIndexer.cs
View file Open in desktop
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.ML.Feature
{
/// <summary>
/// <see cref="StringIndexer"/> encodes a string column of labels to a column of label indices.
/// </summary>
public class StringIndexer : FeatureBase<StringIndexer>, IJvmObjectReferenceProvider
{
private static readonly string s_StringIndexerClassName =
"org.apache.spark.ml.feature.StringIndexer";

/// <summary>
/// Create a <see cref="StringIndexer"/> without any parameters.
/// </summary>
public StringIndexer() : base(s_StringIndexerClassName)
{
}

/// <summary>
/// Create a <see cref="StringIndexer"/> with a UID that is used to give the
/// <see cref="StringIndexer"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public StringIndexer(string uid) : base(s_StringIndexerClassName, uid)
{
}

internal StringIndexer(JvmObjectReference jvmObject) : base(jvmObject)
{
}

JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

/// <summary>
/// Executes the <see cref="StringIndexer"/> and transforms the schema.
/// </summary>
/// <param name="value">The Schema to be transformed</param>
/// <returns>
/// New <see cref="StructType"/> object with the schema <see cref="StructType"/> transformed.
/// </returns>
public StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)_jvmObject.Invoke(
"transformSchema",
DataType.FromJson(_jvmObject.Jvm, value.Json)));

/// <summary>
/// Executes the <see cref="StringIndexer"/> and fits a model to the input data.
/// </summary>
/// <param name="source">The <see cref="DataFrame"/> to fit the model to.</param>
/// <returns><see cref="StringIndexerModel"/></returns>
public StringIndexerModel Fit(DataFrame source) =>
new StringIndexerModel((JvmObjectReference)_jvmObject.Invoke("fit", source));

/// <summary>
/// Gets the HandleInvalid.
/// </summary>
/// <returns>Handle Invalid option</returns>
public string GetHandleInvalid() => (string)_jvmObject.Invoke("getHandleInvalid");

/// <summary>
/// Sets the Handle Invalid option to <see cref="StringIndexer"/>.
/// </summary>
/// <param name="handleInvalid">Handle Invalid option</param>
/// <returns>
/// <see cref="StringIndexer"/> with the Handle Invalid set.
/// </returns>
public StringIndexer SetHandleInvalid(string handleInvalid) =>
WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setHandleInvalid", handleInvalid));

/// <summary>
/// Gets the InputCol.
/// </summary>
/// <returns>Input Col option</returns>
public string GetInputCol() => (string)_jvmObject.Invoke("getInputCol");

/// <summary>
/// Sets the Input Col option to <see cref="StringIndexer"/>.
/// </summary>
/// <param name="inputCol">Input Col option</param>
/// <returns>
/// <see cref="StringIndexer"/> with the Input Col set.
/// </returns>
public StringIndexer SetInputCol(string inputCol) =>
WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setInputCol", inputCol));

/// <summary>
/// Gets the InputCols array.
/// </summary>
/// <returns>Input Cols array option</returns>
public string[] GetInputCols() => (string[])_jvmObject.Invoke("getInputCols");

/// <summary>
/// Sets the Input Cols array option to <see cref="StringIndexer"/>.
/// </summary>
/// <param name="inputCols">Input Cols array option</param>
/// <returns>
/// <see cref="StringIndexer"/> with the Input Cols array set.
/// </returns>
public StringIndexer SetInputCols(string[] inputCols) =>
WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setInputCols", inputCols));

/// <summary>
/// Gets the OutputCol.
/// </summary>
/// <returns>Output Col option</returns>
public string GetOutputCol() => (string)_jvmObject.Invoke("getOutputCol");

/// <summary>
/// Sets the Output Col option to <see cref="StringIndexer"/>.
/// </summary>
/// <param name="outputCol">Output Col option</param>
/// <returns>
/// <see cref="StringIndexer"/> with the Output Col set.
/// </returns>
public StringIndexer SetOutputCol(string outputCol) =>
WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setOutputCol", outputCol));

/// <summary>
/// Gets the OutputCols array.
/// </summary>
/// <returns>Output Cols array option</returns>
public string[] GetOutputCols() => (string[])_jvmObject.Invoke("getOutputCols");

/// <summary>
/// Sets the Output Cols array option to <see cref="StringIndexer"/>.
/// </summary>
/// <param name="outputCols">Output Cols array option</param>
/// <returns>
/// <see cref="StringIndexer"/> with the Output Cols array set.
/// </returns>
public StringIndexer SetOutputCols(string[] outputCols) =>
WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setOutputCol", outputCols));

/// <summary>
/// Gets the String Order Type.
/// </summary>
/// <returns>String Order Type</returns>
public string GetStringOrderType() => (string)_jvmObject.Invoke("getStringOrderType");

/// <summary>
/// Sets the String Order Type to <see cref="StringIndexer"/>.
/// </summary>
/// <param name="stringOrderType">String Order Type</param>
/// <returns>
/// <see cref="StringIndexer"/> with the String Order Type set.
/// </returns>
public StringIndexer SetStringOrderType(string stringOrderType) =>
WrapAsStringIndexer((JvmObjectReference)_jvmObject.Invoke("setStringOrderType", stringOrderType));

/// <summary>
/// Loads the <see cref="StringIndexer"/> that was previously saved using Save.
/// </summary>
/// <param name="path">The path the previous <see cref="StringIndexer"/> was saved to</param>
/// <returns>New <see cref="StringIndexer"/> object, loaded from path</returns>
public static StringIndexer Load(string path) =>
WrapAsStringIndexer(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_StringIndexerClassName,
"load",
path));

private static StringIndexer WrapAsStringIndexer(object obj) =>
new StringIndexer((JvmObjectReference)obj);
}
}
Loading

AltStyle によって変換されたページ (->オリジナル) /