ML.NET / src / Microsoft.ML.FastTree / GamClassification.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Model;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers.FastTree;

[assembly: LoadableClass(GamBinaryTrainer.Summary,
    typeof(GamBinaryTrainer), typeof(GamBinaryTrainer.Options),
    new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
    GamBinaryTrainer.UserNameValue,
    GamBinaryTrainer.LoadNameValue,
    GamBinaryTrainer.ShortName, DocName = "trainer/GAM.md")]

[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(GamBinaryModelParameters), null, typeof(SignatureLoadModel),
    "GAM Binary Class Predictor",
    GamBinaryModelParameters.LoaderSignature)]
namespace Microsoft.ML.Trainers.FastTree
{
    /// <summary>
    /// The <see cref="IEstimator{TTransformer}"/> for training a binary classification model with generalized additive models (GAM).
    /// </summary>
    /// <remarks>
    /// <format type="text/markdown"><![CDATA[
    /// To create this trainer, use [Gam](xref:Microsoft.ML.TreeExtensions.Gam(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,System.String,System.String,System.String,System.Int32,System.Int32,System.Double))
    /// or [Gam(Options)](xref:Microsoft.ML.TreeExtensions.Gam(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FastTree.GamBinaryTrainer.Options)).
    ///
    /// [!include[io](~/../docs/samples/docs/api-reference/io-columns-binary-classification.md)]
    ///
    /// ### Trainer Characteristics
    /// | | |
    /// | -- | -- |
    /// | Machine learning task | Binary classification |
    /// | Is normalization required? | No |
    /// | Is caching required? | No |
    /// | Required NuGet in addition to Microsoft.ML | Microsoft.ML.FastTree |
    ///
    /// [!include[algorithm](~/../docs/samples/docs/api-reference/algo-details-gam.md)]
    /// ]]>
    /// </format>
    /// </remarks>
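    /// <example>
    /// A minimal usage sketch, not taken from this file. It assumes a hypothetical
    /// <c>IDataView</c> named <c>trainData</c> with a Boolean "Label" column and a
    /// float-vector "Features" column; arguments not shown keep their defaults.
    /// <code>
    /// // Assumption: trainData was loaded elsewhere, for example through mlContext.Data.
    /// var mlContext = new MLContext();
    /// // Gam is the catalog extension method referenced in the remarks for this trainer.
    /// var trainer = mlContext.BinaryClassification.Trainers.Gam(
    ///     labelColumnName: "Label",
    ///     featureColumnName: "Features");
    /// // Fit returns a BinaryPredictionTransformer wrapping the calibrated GAM model.
    /// var model = trainer.Fit(trainData);
    /// </code>
    /// </example>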
    /// <seealso cref="TreeExtensions.Gam(BinaryClassificationCatalog.BinaryClassificationTrainers, string, string, string, int, int, double)"/>
    /// <seealso cref="TreeExtensions.Gam(BinaryClassificationCatalog.BinaryClassificationTrainers, GamBinaryTrainer.Options)"/>
    /// <seealso cref="Options"/>
    public sealed class GamBinaryTrainer :
        GamTrainerBase<GamBinaryTrainer.Options,
        BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>,
        CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
    {
        /// <summary>
        /// Options for the <see cref="GamBinaryTrainer"/> as used in
        /// [Gam(Options)](xref:Microsoft.ML.TreeExtensions.Gam(Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers,Microsoft.ML.Trainers.FastTree.GamBinaryTrainer.Options)).
        /// </summary>
        public sealed class Options : OptionsBase
        {
            /// <summary>
            /// Whether to use derivatives optimized for unbalanced training data.
            /// </summary>
            [Argument(ArgumentType.LastOccurenceWins, HelpText = "Should we use derivatives optimized for unbalanced sets", ShortName = "us")]
            [TGUI(Label = "Optimize for unbalanced")]
            public bool UnbalancedSets = false;
        }

        internal const string LoadNameValue = "BinaryClassificationGamTrainer";
        internal const string UserNameValue = "Generalized Additive Model for Binary Classification";
        internal const string ShortName = "gam";
        private readonly double _sigmoidParameter;

        private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
        private protected override bool NeedCalibration => true;

        /// <summary>
        /// Initializes a new instance of <see cref="GamBinaryTrainer"/>.
        /// </summary>
        internal GamBinaryTrainer(IHostEnvironment env, Options options)
            : base(env, options, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
        {
            _sigmoidParameter = 1;
        }

        /// <summary>
        /// Initializes a new instance of <see cref="GamBinaryTrainer"/>.
        /// </summary>
        /// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
        /// <param name="labelColumnName">The name of the label column.</param>
        /// <param name="featureColumnName">The name of the feature column.</param>
        /// <param name="rowGroupColumnName">The name for the column containing the example weight.</param>
        /// <param name="numberOfIterations">The number of iterations to use in learning the features.</param>
        /// <param name="learningRate">The learning rate. GAMs work best with a small learning rate.</param>
        /// <param name="maximumBinCountPerFeature">The maximum number of bins to use to approximate features.</param>
        internal GamBinaryTrainer(IHostEnvironment env,
            string labelColumnName = DefaultColumnNames.Label,
            string featureColumnName = DefaultColumnNames.Features,
            string rowGroupColumnName = null,
            int numberOfIterations = GamDefaults.NumberOfIterations,
            double learningRate = GamDefaults.LearningRate,
            int maximumBinCountPerFeature = GamDefaults.MaximumBinCountPerFeature)
            : base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumnName), featureColumnName, rowGroupColumnName, numberOfIterations, learningRate, maximumBinCountPerFeature)
        {
            _sigmoidParameter = 1;
        }

        private protected override void CheckLabel(RoleMappedData data)
        {
            data.CheckBinaryLabel();
        }

        private static bool[] ConvertTargetsToBool(double[] targets)
        {
            bool[] boolArray = new bool[targets.Length];
            int innerLoopSize = 1 + targets.Length / BlockingThreadPool.NumThreads;
            var actions = new Action[(int)Math.Ceiling(1.0 * targets.Length / innerLoopSize)];
            var actionIndex = 0;
            for (int d = 0; d < targets.Length; d += innerLoopSize)
            {
                var fromDoc = d;
                var toDoc = Math.Min(d + innerLoopSize, targets.Length);
                actions[actionIndex++] = () =>
                {
                    for (int doc = fromDoc; doc < toDoc; doc++)
                        boolArray[doc] = targets[doc] > 0;
                };
            }
            Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
            return boolArray;
        }

        private protected override CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
        {
            TrainBase(context);
            var predictor = new GamBinaryModelParameters(Host,
                BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
            var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
            return new ValueMapperCalibratedModelParameters<GamBinaryModelParameters, PlattCalibrator>(Host, predictor, calibrator);
        }

        private protected override ObjectiveFunctionBase CreateObjectiveFunction()
        {
            return new FastTreeBinaryTrainer.ObjectiveImpl(
                TrainSet,
                ConvertTargetsToBool(TrainSet.Targets),
                GamTrainerOptions.LearningRate,
                0,
                _sigmoidParameter,
                GamTrainerOptions.UnbalancedSets,
                GamTrainerOptions.MaximumTreeOutput,
                GamTrainerOptions.GetDerivativesSampleRate,
                false,
                GamTrainerOptions.Seed,
                ParallelTraining
            );
        }

        private protected override void DefinePruningTest()
        {
            var validTest = new BinaryClassificationTest(ValidSetScore,
                ConvertTargetsToBool(ValidSet.Targets), _sigmoidParameter);
            // As per FastTreeClassification.ConstructOptimizationAlgorithm()
            PruningLossIndex = GamTrainerOptions.UnbalancedSets ? 3 /*Unbalanced sets loss*/ : 1 /*normal loss*/;
            PruningTest = new TestHistory(validTest, PruningLossIndex);
        }

        private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
            MakeTransformer(CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
            => new BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);

        /// <summary>
        /// Trains a <see cref="GamBinaryTrainer"/> using both training and validation data, returns
        /// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
        /// </summary>
        public BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
            => TrainTransformer(trainData, validationData);

        private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
        {
            return new[]
            {
                new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation())),
                new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true))),
                new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false, new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation()))
            };
        }
    }

    /// <summary>
    /// Model parameters for <see cref="GamBinaryTrainer"/>.
    /// </summary>
    public sealed class GamBinaryModelParameters : GamModelParametersBase, IPredictorProducing<float>
    {
        internal const string LoaderSignature = "BinaryClassGamPredictor";
        private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

        /// <summary>
        /// Construct a new Binary Classification GAM with the defined properties.
        /// </summary>
        /// <param name="env">The Host Environment</param>
        /// <param name="binUpperBounds">An array of arrays of bin-upper-bounds for each feature.</param>
        /// <param name="binEffects">An array of arrays of effect sizes for each bin for each feature.</param>
        /// <param name="intercept">The intercept term for the model. Also referred to as the bias or the mean effect.</param>
        /// <param name="inputLength">The number of features passed from the dataset. Used when the number of input features is
        /// different than the number of shape functions. Use default if all features have a shape function.</param>
        /// <param name="featureToInputMap">A map from the feature shape functions, as described by <paramref name="binUpperBounds"/> and <paramref name="binEffects"/>,
        /// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
        /// a shape function.</param>
        internal GamBinaryModelParameters(IHostEnvironment env,
            double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength, int[] featureToInputMap)
            : base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }

        private GamBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, LoaderSignature, ctx) { }

        private static VersionInfo GetVersionInfo()
        {
            return new VersionInfo(
                modelSignature: "GAM BINP",
                // verWrittenCur: 0x00010001, // Initial
                // verWrittenCur: 0x00010001, // Added Intercept but collided from release 0.6-0.9
                verWrittenCur: 0x00010002, // Added Intercept (version revved to address collisions)
                verReadableCur: 0x00010002,
                verWeCanReadBack: 0x00010001,
                loaderSignature: LoaderSignature,
                loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName);
        }

        private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            var predictor = new GamBinaryModelParameters(env, ctx);
            ICalibrator calibrator;
            ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
            if (calibrator == null)
                return predictor;
            return new SchemaBindableCalibratedModelParameters<GamBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
        }

        private protected override void SaveCore(ModelSaveContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel();
            ctx.SetVersionInfo(GetVersionInfo());
            base.SaveCore(ctx);
        }
    }
}