云计算 waitig 525℃ 百度已收录 0评论

def fit(self, X, y, sample_weight=None)


Build a forest of trees from the training set (X, y).

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The training input samples. Internally, its dtype will be converted to
        ``dtype=np.float32``. If a sparse matrix is provided, it will be
        converted into a sparse ``csc_matrix``.
    y : array-like, shape = [n_samples] or [n_samples, n_outputs]
        The target values (class labels in classification, real numbers in
        regression).
    sample_weight : array-like, shape = [n_samples] or None
        Sample weights. If None, then samples are equally weighted. Splits
        that would create child nodes with net zero or negative weight are
        ignored while searching for a split in each node. In the case of
        classification, splits are also ignored if they would result in any
        single class carrying a negative weight in either child node.

    Returns
    -------
    self : object
        Returns self.

以下是 sklearn.ensemble 库中随机森林分类算法的 fit() 函数的具体实现。注释部分是我添加的内容,没有被注释的是程序的源代码。

def fit(self, X, y, sample_weight=None):
    """Build a forest of trees from the training set (X, y).

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The training input samples. Internally, its dtype will be converted
        to ``DTYPE`` (``np.float32``). If a sparse matrix is provided, it
        will be converted into a sparse ``csc_matrix``.
    y : array-like, shape = [n_samples] or [n_samples, n_outputs]
        The target values (class labels in classification, real numbers in
        regression).
    sample_weight : array-like, shape = [n_samples] or None
        Sample weights. If None, then samples are equally weighted. Splits
        that would create child nodes with net zero or negative weight are
        ignored while searching for a split in each node. In the case of
        classification, splits are also ignored if they would result in any
        single class carrying a negative weight in either child node.

    Returns
    -------
    self : object
        Returns self.

    Raises
    ------
    ValueError
        If ``oob_score`` is requested without ``bootstrap``, or if
        ``n_estimators`` is smaller than the number of already-fitted trees
        when warm-starting.
    """
    # Validate or convert input data with check_array(): X is cast to DTYPE
    # and sparse input is converted to CSC; y keeps its own dtype and may be 1-D.
    X = check_array(X, accept_sparse="csc", dtype=DTYPE)
    y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
    if sample_weight is not None:
        sample_weight = check_array(sample_weight, ensure_2d=False)
    if issparse(X):
        # Pre-sort indices to avoid that each individual tree of the
        # ensemble sorts the indices.
        X.sort_indices()

    # The two dimensions of X are the number of samples and of features.
    n_samples, self.n_features_ = X.shape

    # Remap output: guarantee y has at least one dimension
    # (np.atleast_1d(1) -> array([1])).
    y = np.atleast_1d(y)
    if y.ndim == 2 and y.shape[1] == 1:
        warn("A column-vector y was passed when a 1d array was"
             " expected. Please change the shape of y to "
             "(n_samples,), for example using ravel().",
             DataConversionWarning, stacklevel=2)

    if y.ndim == 1:
        # np.reshape preserves data contiguity, which indexing with
        # [:, np.newaxis] does not.
        y = np.reshape(y, (-1, 1))

    # Number of outputs (columns of y); 1 for a single-output problem.
    self.n_outputs_ = y.shape[1]

    y, expanded_class_weight = self._validate_y_class_weight(y)

    # Ensure y is a C-contiguous DOUBLE (float64) array; only copy when the
    # dtype or the memory layout does not already match.
    if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
        y = np.ascontiguousarray(y, dtype=DOUBLE)

    # Combine per-sample class weights with user sample weights when both
    # exist; otherwise the class weights alone become the sample weights.
    if expanded_class_weight is not None:
        if sample_weight is not None:
            sample_weight = sample_weight * expanded_class_weight
        else:
            sample_weight = expanded_class_weight

    # Check parameters: validate the estimator and the n_estimators
    # attribute, and set the `base_estimator_` attribute.
    self._validate_estimator()

    if not self.bootstrap and self.oob_score:
        raise ValueError("Out of bag estimation only available"
                         " if bootstrap=True")

    # Turn the random_state parameter into a np.random.RandomState instance.
    random_state = check_random_state(self.random_state)

    if not self.warm_start or not hasattr(self, "estimators_"):
        # Free allocated memory, if any
        self.estimators_ = []

    n_more_estimators = self.n_estimators - len(self.estimators_)

    if n_more_estimators < 0:
        raise ValueError('n_estimators=%d must be larger or equal to '
                         'len(estimators_)=%d when warm_start==True'
                         % (self.n_estimators, len(self.estimators_)))

    elif n_more_estimators == 0:
        warn("Warm-start fitting without increasing n_estimators does not "
             "fit new trees.")
    else:
        if self.warm_start and len(self.estimators_) > 0:
            # We draw from the random state to get the random state we
            # would have got if we hadn't used a warm_start.
            random_state.randint(MAX_INT, size=len(self.estimators_))

        # Create the new, still unfitted trees (in order, so each consumes
        # the shared random_state deterministically). They are added to
        # self.estimators_ only after fitting, via the extend() below.
        trees = [self._make_estimator(append=False,
                                      random_state=random_state)
                 for _ in range(n_more_estimators)]

        # Parallel loop: we use the threading backend as the Cython code
        # for fitting the trees is internally releasing the Python GIL
        # making threading always more efficient than multiprocessing in
        # that case. delayed(f)(*args) captures each call for Parallel.
        # NOTE(review): _parallel_build_trees is the per-tree fitting helper
        # defined at module level in sklearn's forest module (restored from
        # upstream; confirm against the sklearn version this was copied from).
        trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
                         backend="threading")(
            delayed(_parallel_build_trees)(
                t, self, X, y, sample_weight, i, len(trees),
                verbose=self.verbose, class_weight=self.class_weight)
            for i, t in enumerate(trees))

        # Collect newly grown trees
        self.estimators_.extend(trees)

    if self.oob_score:
        self._set_oob_score(X, y)

    # Decapsulate classes_ attributes: for a single output, unwrap the
    # per-output lists into their only element.
    if hasattr(self, "classes_") and self.n_outputs_ == 1:
        self.n_classes_ = self.n_classes_[0]
        self.classes_ = self.classes_[0]

    return self

点赞 (0)分享 (0)