Day 7: The Treachery of Whales

This problem asks us, in effect, to find the number that minimizes the sum of particular measures of error. (Since the number of points is fixed, we will use “sum” and “mean” of the errors interchangeably, as they merely differ by a constant factor, the number of points.) A naive solution would iterate through all possibilities, but knowledge of statistics offers us a shortcut. In both parts, we aim to minimize an $l_p$-norm.


Setup

Reading in the input:

use crate::{utils::abs_diff, Answer};

fn read_input(s: &str) -> Option<Vec<usize>> {
	s.trim()
		.split(',')
		.map(|n| n.parse().ok())
		.collect::<Option<Vec<_>>>()
}

fn ans_for_input(input: &str) -> Answer<usize, usize> {
	let nums = read_input(input).unwrap();
	(7, (pt1(&nums), pt2(&nums))).into()
}

pub fn ans() -> Answer<usize, usize> {
	ans_for_input(include_str!("input.txt"))
}

Part 1

Part 1 asks us, in effect, to find the number $x$ that minimizes the mean absolute deviation, or the $l_1$-norm of the error. The number that does this is simply the median of the dataset. (In this problem, the median may be a half-integer, but the solution has to be an integer, so we can just round the median down.)

Proof: If you are not at the median and move towards it, you are moving toward at least as many data points as you are moving away from, which at best decreases the mean absolute deviation and at worst leaves it unchanged.
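As a sanity check of that claim, a brute-force sweep over every candidate position on the puzzle's sample input confirms that the median achieves the minimum total absolute deviation. (This is a standalone sketch; `l1_cost` is a helper named here for illustration, and it uses `usize::abs_diff` from the standard library rather than the crate's own helper.)

```rust
// Total absolute deviation of the dataset from a candidate position `x`.
fn l1_cost(x: usize, nums: &[usize]) -> usize {
	nums.iter().map(|&n| n.abs_diff(x)).sum()
}

fn main() {
	// Sample input from the puzzle statement.
	let nums = [16, 1, 2, 0, 4, 2, 7, 1, 2, 14];
	let max = *nums.iter().max().unwrap();
	// Try every candidate position and keep the cheapest.
	let best = (0..=max).min_by_key(|&x| l1_cost(x, &nums)).unwrap();
	let mut sorted = nums.to_vec();
	sorted.sort_unstable();
	// Upper median; for this sample both middle elements are 2.
	let median = sorted[sorted.len() / 2];
	// The brute-force optimum and the median agree on the cost (37).
	assert_eq!(l1_cost(best, &nums), l1_cost(median, &nums));
}
```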

fn pt1<V: AsRef<[usize]>>(nums: V) -> usize {
	let mut nums = nums.as_ref().to_vec();
	nums.sort_unstable();
	// Lower and upper medians; these coincide when the length is odd.
	let datum_below = nums[(nums.len() - 1) / 2];
	let datum_above = nums[nums.len() / 2];
	// Integer division rounds a half-integer median down.
	let median = (datum_below + datum_above) / 2;
	nums.iter().map(|&n| abs_diff(n, median)).sum()
}

Part 2

In Part 2, the cost associated with a distance of $n$ is $\sum_{k=1}^n k = \frac{n(n+1)}{2}$. Since $n$ is a nonnegative integer, $n^2 \ge n$, and so the $n$ term in $n^2 + n$ is “not relatively penalized more than” the $n^2$ term, which means the mean of the $\frac{n(n+1)}{2}$ is minimized at, or within $\frac{1}{2}$ of, whatever minimizes the mean of the $n^2$. (It would be a problem if $n$ exceeded $n^2$, as then it might dominate the sum and we would no longer be looking to minimize the mean of the $n^2$.) The mean of the $n^2$ is the mean squared error (MSE, or squared $l_2$-norm of the error), and it is a well-known fact of statistics that the MSE is minimized by the arithmetic mean of the data.

The only catch is that the arithmetic mean need not be an integer, but the solution to the problem must be. So we simply try the two integers on either side of the arithmetic mean (both of which are the mean itself if it happens to be an integer) and pick whichever leads to the smaller cost. This works because the total cost is a convex function of the chosen position, i.e., a local minimum is the global minimum.
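On the sample input this plays out concretely: the mean is $4.9$, so the candidates are $4$ and $5$, and $5$ wins. (A standalone sketch; `tri_cost` is a helper named here for illustration.)

```rust
// Triangular-number fuel cost of aligning the dataset at position `x`.
fn tri_cost(x: usize, nums: &[usize]) -> usize {
	nums.iter()
		.map(|&n| {
			let d = n.abs_diff(x);
			d * (d + 1) / 2
		})
		.sum()
}

fn main() {
	// Sample input from the puzzle statement.
	let nums = [16, 1, 2, 0, 4, 2, 7, 1, 2, 14];
	let sum: usize = nums.iter().sum(); // 49, so the mean is 4.9
	let lo = sum / nums.len(); // floor(4.9) = 4
	let hi = lo + 1; //            ceil(4.9) = 5
	assert_eq!(tri_cost(lo, &nums), 170);
	assert_eq!(tri_cost(hi, &nums), 168); // the true minimum
}
```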

fn pt2<V: AsRef<[usize]>>(nums: V) -> usize {
	fn cost(mean: usize, nums: &[usize]) -> usize {
		nums.iter()
			.map(|&n| {
				let diff = abs_diff(n, mean);
				diff * (diff + 1) / 2
			})
			.sum()
	}

	let nums = nums.as_ref();
	let sum = nums.iter().sum::<usize>();
	let len = nums.len();

	let mean_rounded_down = sum / len; // floor of the mean

	if sum % len == 0 {
		cost(mean_rounded_down, nums)
	} else {
		// Ceiling division: for len > 0, (sum - 1) / len + 1 == ceil(sum / len).
		let mean_rounded_up = (sum - 1) / len + 1;
		cost(mean_rounded_down, nums).min(cost(mean_rounded_up, nums))
	}
}
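One way to gain confidence in the mean-based shortcut is to cross-check it against an exhaustive scan over every candidate position, which is what the naive solution from the introduction would do. (A standalone sketch; `pt2_brute` is introduced here and is $O(\text{range} \times n)$, fine for inputs of this size.)

```rust
// Exhaustive scan over every candidate position, for cross-checking the
// closed-form shortcut based on the arithmetic mean.
fn pt2_brute(nums: &[usize]) -> usize {
	let max = *nums.iter().max().unwrap();
	(0..=max)
		.map(|x| {
			nums.iter()
				.map(|&n| {
					let d = n.abs_diff(x);
					d * (d + 1) / 2
				})
				.sum::<usize>()
		})
		.min()
		.unwrap()
}

fn main() {
	// Sample input from the puzzle statement; the brute-force minimum
	// matches the cost at the integer neighbor of the mean (position 5).
	let nums = [16, 1, 2, 0, 4, 2, 7, 1, 2, 14];
	assert_eq!(pt2_brute(&nums), 168);
}
```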
