|
668 | 668 | },
|
669 | 669 | {
|
670 | 670 | "cell_type": "code",
|
671 | | - "execution_count": 11, |
| 671 | + "execution_count": 18, |
672 | 672 | "metadata": {},
|
673 | 673 | "outputs": [],
|
674 | 674 | "source": [
|
675 | 675 | "from sklearn.model_selection import GridSearchCV\n",
|
676 | 676 | "\n",
|
677 | | - "def get_bset_mdoel_and_accuracy(model, params, X, y):\n", |
| 677 | + "def get_best_model_and_accuracy(model, params, X, y):\n", |
678 | 678 | " grid = GridSearchCV(model, #要搜索的模型\n",
|
679 | 679 | " params, #要尝试的参数\n",
|
680 | 680 | " error_score = 0.)\n",
|
|
695 | 695 | "cell_type": "markdown",
|
696 | 696 | "metadata": {},
|
697 | 697 | "source": [
|
698 | | - "#" |
| 698 | + "# 1. 选择合适模型" |
699 | 699 | ]
|
| 700 | + }, |
| 701 | + { |
| 702 | + "cell_type": "code", |
| 703 | + "execution_count": 14, |
| 704 | + "metadata": {}, |
| 705 | + "outputs": [], |
| 706 | + "source": [ |
| 707 | + "from sklearn.linear_model import LogisticRegression\n", |
| 708 | + "from sklearn.neighbors import KNeighborsClassifier\n", |
| 709 | + "from sklearn.tree import DecisionTreeClassifier\n", |
| 710 | + "from sklearn.ensemble import RandomForestClassifier\n", |
| 711 | + "\n", |
| 712 | + "\n", |
| 713 | + "# Set up some parameters for our grid search\n", |
| 714 | + "# We will start with four different machine learning models\n", |
| 715 | + "# logistic regression, KNN, Decision Tree, and Random Forest\n", |
| 716 | + "lr_params = {'C':[1e-1, 1e0, 1e1, 1e2], 'penalty':['l1', 'l2']}\n", |
| 717 | + "knn_params = {'n_neighbors': [1, 3, 5, 7]}\n", |
| 718 | + "tree_params = {'max_depth': [None, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]}\n", |
| 719 | + "forest_params = {'n_estimators': [10, 50, 100], 'max_depth': [None, 1, 3, 5, 7]}\n", |
| 720 | + "\n", |
| 721 | + "\n", |
| 722 | + "# instantiate the four machine learning models\n", |
| 723 | + "lr = LogisticRegression()\n", |
| 724 | + "knn = KNeighborsClassifier()\n", |
| 725 | + "d_tree = DecisionTreeClassifier()\n", |
| 726 | + "forest = RandomForestClassifier()" |
| 727 | + ] |
| 728 | + }, |
| 729 | + { |
| 730 | + "cell_type": "code", |
| 731 | + "execution_count": 16, |
| 732 | + "metadata": {}, |
| 733 | + "outputs": [ |
| 734 | + { |
| 735 | + "name": "stderr", |
| 736 | + "output_type": "stream", |
| 737 | + "text": [ |
| 738 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n", |
| 739 | + " warnings.warn(CV_WARNING, FutureWarning)\n", |
| 740 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 741 | + " FutureWarning)\n", |
| 742 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 743 | + " FutureWarning)\n", |
| 744 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 745 | + " FutureWarning)\n", |
| 746 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 747 | + " FutureWarning)\n", |
| 748 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 749 | + " FutureWarning)\n", |
| 750 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 751 | + " FutureWarning)\n", |
| 752 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 753 | + " FutureWarning)\n", |
| 754 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 755 | + " FutureWarning)\n", |
| 756 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 757 | + " FutureWarning)\n", |
| 758 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 759 | + " FutureWarning)\n", |
| 760 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 761 | + " FutureWarning)\n", |
| 762 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 763 | + " FutureWarning)\n", |
| 764 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 765 | + " FutureWarning)\n", |
| 766 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 767 | + " FutureWarning)\n", |
| 768 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 769 | + " FutureWarning)\n", |
| 770 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 771 | + " FutureWarning)\n", |
| 772 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 773 | + " FutureWarning)\n", |
| 774 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 775 | + " FutureWarning)\n", |
| 776 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 777 | + " FutureWarning)\n", |
| 778 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 779 | + " FutureWarning)\n", |
| 780 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 781 | + " FutureWarning)\n", |
| 782 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 783 | + " FutureWarning)\n", |
| 784 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 785 | + " FutureWarning)\n", |
| 786 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 787 | + " FutureWarning)\n", |
| 788 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", |
| 789 | + " FutureWarning)\n" |
| 790 | + ] |
| 791 | + }, |
| 792 | + { |
| 793 | + "name": "stdout", |
| 794 | + "output_type": "stream", |
| 795 | + "text": [ |
| 796 | + "Best Accuracy: 0.8095666666666667\n", |
| 797 | + "Best Parameters: {'C': 0.1, 'penalty': 'l1'}\n", |
| 798 | + "Average Time to Fit (s): 0.474\n", |
| 799 | + "Average Time to Score (s): 0.071\n" |
| 800 | + ] |
| 801 | + } |
| 802 | + ], |
| 803 | + "source": [ |
| 804 | + "get_best_model_and_accuracy(lr, lr_params, X, y)" |
| 805 | + ] |
| 806 | + }, |
| 807 | + { |
| 808 | + "cell_type": "code", |
| 809 | + "execution_count": 19, |
| 810 | + "metadata": {}, |
| 811 | + "outputs": [ |
| 812 | + { |
| 813 | + "name": "stderr", |
| 814 | + "output_type": "stream", |
| 815 | + "text": [ |
| 816 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n", |
| 817 | + " warnings.warn(CV_WARNING, FutureWarning)\n" |
| 818 | + ] |
| 819 | + }, |
| 820 | + { |
| 821 | + "name": "stdout", |
| 822 | + "output_type": "stream", |
| 823 | + "text": [ |
| 824 | + "Best Accuracy: 0.7602333333333333\n", |
| 825 | + "Best Parameters: {'n_neighbors': 7}\n", |
| 826 | + "Average Time to Fit (s): 0.02\n", |
| 827 | + "Average Time to Score (s): 0.796\n" |
| 828 | + ] |
| 829 | + } |
| 830 | + ], |
| 831 | + "source": [ |
| 832 | + "get_best_model_and_accuracy(knn, knn_params, X, y)" |
| 833 | + ] |
| 834 | + }, |
| 835 | + { |
| 836 | + "cell_type": "code", |
| 837 | + "execution_count": 22, |
| 838 | + "metadata": {}, |
| 839 | + "outputs": [ |
| 840 | + { |
| 841 | + "name": "stdout", |
| 842 | + "output_type": "stream", |
| 843 | + "text": [ |
| 844 | + "{'classifier__n_neighbors': [1, 3, 5, 7]}\n" |
| 845 | + ] |
| 846 | + }, |
| 847 | + { |
| 848 | + "name": "stderr", |
| 849 | + "output_type": "stream", |
| 850 | + "text": [ |
| 851 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n", |
| 852 | + " warnings.warn(CV_WARNING, FutureWarning)\n" |
| 853 | + ] |
| 854 | + }, |
| 855 | + { |
| 856 | + "name": "stdout", |
| 857 | + "output_type": "stream", |
| 858 | + "text": [ |
| 859 | + "Best Accuracy: 0.8008\n", |
| 860 | + "Best Parameters: {'classifier__n_neighbors': 7}\n", |
| 861 | + "Average Time to Fit (s): 0.041\n", |
| 862 | + "Average Time to Score (s): 6.52\n" |
| 863 | + ] |
| 864 | + } |
| 865 | + ], |
| 866 | + "source": [ |
| 867 | + "from sklearn.pipeline import Pipeline\n", |
| 868 | + "from sklearn.preprocessing import StandardScaler\n", |
| 869 | + "\n", |
| 870 | + "# prefix each KNN parameter name with the pipeline step name\n", |
| 871 | + "# ('classifier__') so the grid search can route it to that step\n", |
| 872 | + "knn_pipe_params = {'classifier__{}'.format(k): v for k, v in knn_params.items()}\n", |
| 873 | + "print(knn_pipe_params)\n", |
| 874 | + "\n", |
| 875 | + "# KNN requires a standard scaler because it uses Euclidean distance as\n", |
| 876 | + "# the main measure for predicting observations\n", |
| 877 | + "knn_pipe = Pipeline([('scale', StandardScaler()), ('classifier', knn)])\n", |
| 878 | + "\n", |
| 879 | + "# quick to fit, very slow to predict\n", |
| 880 | + "get_best_model_and_accuracy(knn_pipe, knn_pipe_params, X, y)" |
| 881 | + ] |
| 882 | + }, |
| 883 | + { |
| 884 | + "cell_type": "code", |
| 885 | + "execution_count": 23, |
| 886 | + "metadata": {}, |
| 887 | + "outputs": [ |
| 888 | + { |
| 889 | + "name": "stderr", |
| 890 | + "output_type": "stream", |
| 891 | + "text": [ |
| 892 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n", |
| 893 | + " warnings.warn(CV_WARNING, FutureWarning)\n" |
| 894 | + ] |
| 895 | + }, |
| 896 | + { |
| 897 | + "name": "stdout", |
| 898 | + "output_type": "stream", |
| 899 | + "text": [ |
| 900 | + "Best Accuracy: 0.8202666666666667\n", |
| 901 | + "Best Parameters: {'max_depth': 3}\n", |
| 902 | + "Average Time to Fit (s): 0.23\n", |
| 903 | + "Average Time to Score (s): 0.003\n" |
| 904 | + ] |
| 905 | + } |
| 906 | + ], |
| 907 | + "source": [ |
| 908 | + "get_best_model_and_accuracy(d_tree, tree_params, X, y)" |
| 909 | + ] |
| 910 | + }, |
| 911 | + { |
| 912 | + "cell_type": "code", |
| 913 | + "execution_count": 24, |
| 914 | + "metadata": {}, |
| 915 | + "outputs": [ |
| 916 | + { |
| 917 | + "name": "stderr", |
| 918 | + "output_type": "stream", |
| 919 | + "text": [ |
| 920 | + "/Users/super/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n", |
| 921 | + " warnings.warn(CV_WARNING, FutureWarning)\n" |
| 922 | + ] |
| 923 | + }, |
| 924 | + { |
| 925 | + "name": "stdout", |
| 926 | + "output_type": "stream", |
| 927 | + "text": [ |
| 928 | + "Best Accuracy: 0.8189\n", |
| 929 | + "Best Parameters: {'max_depth': 7, 'n_estimators': 100}\n", |
| 930 | + "Average Time to Fit (s): 1.01\n", |
| 931 | + "Average Time to Score (s): 0.043\n" |
| 932 | + ] |
| 933 | + } |
| 934 | + ], |
| 935 | + "source": [ |
| 936 | + "get_best_model_and_accuracy(forest, forest_params, X, y)" |
| 937 | + ] |
| 938 | + }, |
| 939 | + { |
| 940 | + "cell_type": "code", |
| 941 | + "execution_count": null, |
| 942 | + "metadata": {}, |
| 943 | + "outputs": [], |
| 944 | + "source": [] |
700 | 945 | }
|
701 | 946 | ],
|
702 | 947 | "metadata": {
|
|
0 commit comments