Bagaimana cara menghitung rata-rata dan standar deviasi dataset training di PyTorch?

def mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Every pixel of every image counts as one sample of its channel, so the
    result is the statistic of the whole dataset. Averaging per-image
    standard deviations instead would ignore how much the image means differ
    from one another and would underestimate the dataset spread.

    Batches are merged with the parallel (Chan et al.) mean/variance update
    in float64. This avoids the cancellation of the naive
    ``E[x**2] - E[x]**2`` formula.

    Args:
        loader: Iterable of ``(images, target)`` pairs. ``images`` must have
            shape ``(batch, channels, ...)``; ``target`` is ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel. They
        use the floating dtype that ``images * 1.0`` has for the first
        non-empty batch (e.g. ``float32`` for ``uint8`` images). ``std`` uses
        Bessel's correction, like the default of ``Tensor.std``.

    Raises:
        ValueError: If ``loader`` yields no pixels at all.
    """
    import torch

    count = 0  # number of pixels per channel merged so far
    mean = 0.0  # running per-channel mean; becomes a float64 tensor
    m2 = 0.0  # running per-channel sum of squared deviations from ``mean``
    out_dtype = None
    for images, _ in loader:
        if images.numel() == 0:
            continue  # empty batch: nothing to merge (and avoids 0 / 0 below)
        if out_dtype is None:
            out_dtype = torch.result_type(images, 1.0)
        # reshape (unlike view) also accepts non-contiguous batches.
        flat = images.reshape(images.size(0), images.size(1), -1).to(torch.float64)
        batch_count = flat.size(0) * flat.size(2)
        batch_mean = flat.mean(dim=(0, 2))
        batch_m2 = ((flat - batch_mean[None, :, None]) ** 2).sum(dim=(0, 2))

        # Merge this batch's statistics into the running ones.
        total = count + batch_count
        delta = batch_mean - mean
        mean = mean + delta * (batch_count / total)
        m2 = m2 + batch_m2 + delta**2 * (count * batch_count / total)
        count = total

    if count == 0:
        raise ValueError("mean_std: loader yielded no images")
    std = (m2 / (count - 1)).sqrt()
    return mean.to(out_dtype), std.to(out_dtype)

(tensor([112.1058, 126.2224,  79.7943]), tensor([49.5487, 50.0356, 43.0908]))
Determined Dotterel