A mini-batch data generator for machine learning. Posting the code here for future reference:

import numpy as np

# inputs and targets are simply X and y
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt]

Usage:

>>>X = np.random.randn(6,4)
>>>y = np.random.randint(0,2,6)
>>>gen = iterate_minibatches(X, y, 2, shuffle=False)
>>>next(gen)
Out[113]: 
(array([[-1.29894375,  1.64519902,  0.449616  ,  0.2820526 ],
        [ 0.49223679,  0.4552031 ,  0.40525815,  0.22591931]]), array([0, 0]))

>>>next(gen)
Out[114]: 
(array([[ 1.46478028,  0.26404314,  0.79638976,  0.223665  ],
        [ 0.08203239,  1.52025888,  1.2913706 ,  0.38070632]]), array([0, 0]))

>>>next(gen)
Out[115]: 
(array([[-0.53161078, -0.3045918 ,  0.63480717, -0.12489774],
        [-0.41806273,  0.09459128, -0.4725088 , -0.83306486]]), array([1, 0]))

>>>next(gen)
Traceback (most recent call last):

  File "<ipython-input-116-b2c61ce5e131>", line 1, in <module>
    next(gen)

StopIteration
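
In practice the StopIteration above is rarely seen, because a for loop absorbs it. A minimal sketch of the usual pattern follows, assuming a fresh generator is created for every epoch. The data (7 samples, batch size 3) is made up for illustration, and the function is repeated only so the snippet runs on its own. It also shows that the range() bound skips any leftover samples that do not fill a whole batch.

```python
import numpy as np

def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    # Same one-pass generator as above, repeated so this snippet is self-contained.
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], targets[excerpt]

# Made-up data: 7 samples with 4 features each.
X = np.arange(28, dtype=float).reshape(7, 4)
y = np.arange(7) % 2

batches_per_epoch = []
for epoch in range(3):
    n_batches = 0
    # A new generator per epoch; the for loop handles StopIteration silently.
    for xb, yb in iterate_minibatches(X, y, 3, shuffle=True):
        assert xb.shape == (3, 4) and yb.shape == (3,)
        n_batches += 1
    batches_per_epoch.append(n_batches)

print(batches_per_epoch)  # [2, 2, 2]: only 7 // 3 full batches, one sample is skipped
```

With shuffle=True the skipped sample differs from epoch to epoch, so every sample is still likely to be seen over several epochs.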

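In other words, once all the data has been iterated over, the generator is exhausted and cannot be reused. To cycle through the data indefinitely, wrap the body in a while True loop.
Looping version: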

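def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    assert len(inputs) == len(targets)
    # Guard: if batchsize > len(inputs), the loop below would spin forever without yielding.
    assert 0 < batchsize <= len(inputs)
    while True: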
        if shuffle:
            indices = np.arange(len(inputs))
            np.random.shuffle(indices)
        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]
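
Because the looping version never stops on its own, the caller has to decide how many batches to draw. One possible way, sketched here with itertools.islice, is to take a fixed number of steps. The name iterate_minibatches_forever and the data are hypothetical, chosen only to keep the snippet self-contained and distinct from the one-pass version.

```python
import itertools
import numpy as np

def iterate_minibatches_forever(inputs, targets, batchsize, shuffle=False):
    # Hypothetical name; the body mirrors the looping version above.
    assert len(inputs) == len(targets)
    assert 0 < batchsize <= len(inputs)  # otherwise the while loop never yields
    while True:
        if shuffle:
            indices = np.arange(len(inputs))
            np.random.shuffle(indices)  # reshuffled at the start of every pass
        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]

# Made-up data: 6 samples, so batchsize=2 gives 3 batches per pass.
X = np.arange(24, dtype=float).reshape(6, 4)
y = np.arange(6) % 2

gen = iterate_minibatches_forever(X, y, 2, shuffle=False)
steps = list(itertools.islice(gen, 5))  # draw exactly 5 batches, then stop

first_x, _ = steps[0]
fourth_x, _ = steps[3]
print(len(steps))                         # 5
print(np.array_equal(first_x, fourth_x))  # True: the 4th batch wraps around to the start
```

Never call list(gen) or loop over gen without a stopping condition, since the generator is infinite by design.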