diff --git a/test/unit/common/test_statsd_client.py b/test/unit/common/test_statsd_client.py index 7790ddc323..f241367499 100644 --- a/test/unit/common/test_statsd_client.py +++ b/test/unit/common/test_statsd_client.py @@ -371,9 +371,21 @@ class TestGetLabeledStatsdClientConfParsing(BaseTestStatsdClient): str(ctx.exception)) -class TestGetStatsdClientSocket(BaseTestStatsdClient): - def make_test_client(self, conf, *args, **kwargs): - return statsd_client.get_statsd_client(conf, *args, **kwargs) +class CommonBaseTestsMixIn(object): + + # N.B. we use a MixIn here to help maintain/transfer the understanding that + # the tests defined in this "MixIn" are run in multiple concrete TestCase + # subclasses. We can't inherit from TestCase ourselves because unittest + # does not know how to skip abstract common base TestCases - although we + # may explore alternatives in the future. + def make_test_client(self, conf, tail_prefix='', **kwargs): + """ + Concrete TestCase classes should implement this method and have the + following attributes: + * tail_prefix + * expected_prefix_bytes + """ + raise NotImplementedError() def test_ipv4_or_ipv6_hostname_defaults_to_ipv4(self): def stub_getaddrinfo_both_ipv4_and_ipv6(host, port, family, *rest): @@ -391,7 +403,7 @@ class TestGetStatsdClientSocket(BaseTestStatsdClient): client = self.make_test_client({ 'log_statsd_host': 'localhost', 'log_statsd_port': '9876', - }, 'some-name', logger=self.logger) + }, self.tail_prefix, logger=self.logger) self.assertEqual(client._sock_family, socket.AF_INET) self.assertEqual(client._target, ('localhost', 9876)) @@ -470,16 +482,6 @@ class TestGetStatsdClientSocket(BaseTestStatsdClient): # statsd sends will warn in the logs until the DNS failure or invalid # IP address in the configuration is fixed. - -class TestGetLabeledStatsdClientSocket(TestGetStatsdClientSocket): - def make_test_client(self, conf, *args, logger=None): - return statsd_client.get_labeled_statsd_client(conf, logger=logger) - - -class TestGetStatsdClientSending(BaseTestStatsdClient): - """ - Tests here use get_statsd_client to make a StatsdClient. - """ def test_sending_ipv6(self): def fake_getaddrinfo(host, port, *args): # this is what a real getaddrinfo('::1', port, @@ -496,7 +498,7 @@ class TestGetStatsdClientSending(BaseTestStatsdClient): with mock.patch.object(statsd_client.socket, 'getaddrinfo', fake_getaddrinfo): - client = get_statsd_client({ + client = self.make_test_client({ 'log_statsd_host': '::1', 'log_statsd_port': '9876', }, 'some-name', logger=self.logger) @@ -509,10 +511,11 @@ class TestGetStatsdClientSending(BaseTestStatsdClient): client.increment('tunafish') self.assertEqual(fl.get_lines_for_level('warning'), []) self.assertEqual(mock_socket.sent, - [(b'some-name.tunafish:1|c', ('::1', 9876, 0, 0))]) + [(self.expected_prefix_bytes + b'tunafish:1|c', + ('::1', 9876, 0, 0))]) def test_no_exception_when_cant_send_udp_packet(self): - client = get_statsd_client({'log_statsd_host': 'some.host.com'}) + client = self.make_test_client({'log_statsd_host': 'some.host.com'}) fl = debug_logger() client.logger = fl mock_socket = MockUdpSocket(sendto_errno=errno.EPERM) @@ -523,7 +526,7 @@ class TestGetStatsdClientSending(BaseTestStatsdClient): self.assertEqual(fl.get_lines_for_level('warning'), expected) def test_sample_rates(self): - client = get_statsd_client({'log_statsd_host': 'some.host.com'}) + client = self.make_test_client({'log_statsd_host': 'some.host.com'}) mock_socket = MockUdpSocket() self.assertTrue(client.random is random.random) @@ -532,18 +535,16 @@ class TestGetStatsdClientSending(BaseTestStatsdClient): client.random = lambda: 0.50001 self.assertIsNone(client.increment('tribbles', sample_rate=0.5)) - self.assertEqual(len(mock_socket.sent), 0) + self.assertFalse(mock_socket.sent) client.random = lambda: 0.49999 rv = client.increment('tribbles', sample_rate=0.5) self.assertIsInstance(rv, int) - self.assertEqual(len(mock_socket.sent), 1) - - payload = mock_socket.sent[0][0] - self.assertTrue(payload.endswith(b"|@0.5")) + self.assertEqual([(b"tribbles:1|c|@0.5", ('some.host.com', 8125))], + mock_socket.sent) def test_sample_rates_with_sample_rate_factor(self): - client = get_statsd_client({ + client = self.make_test_client({ 'log_statsd_host': 'some.host.com', 'log_statsd_default_sample_rate': '0.82', 'log_statsd_sample_rate_factor': '0.91', @@ -551,30 +552,55 @@ class TestGetStatsdClientSending(BaseTestStatsdClient): effective_sample_rate = 0.82 * 0.91 mock_socket = MockUdpSocket() - self.assertTrue(client.random is random.random) + self.assertIs(client.random, random.random) client._open_socket = lambda *_: mock_socket client.random = lambda: effective_sample_rate + 0.001 client.increment('tribbles') - self.assertEqual(len(mock_socket.sent), 0) + self.assertFalse(mock_socket.sent) client.random = lambda: effective_sample_rate - 0.001 client.increment('tribbles') - self.assertEqual(len(mock_socket.sent), 1) - - payload = mock_socket.sent[0][0] - suffix = ("|@%s" % effective_sample_rate).encode('utf-8') - self.assertTrue(payload.endswith(suffix), payload) + expected = ("tribbles:1|c|@%s" % effective_sample_rate).encode('utf-8') + self.assertEqual([(expected, ('some.host.com', 8125))], + mock_socket.sent) + # caller specifies non-default sample rate + mock_socket = MockUdpSocket() effective_sample_rate = 0.587 * 0.91 + client.random = lambda: effective_sample_rate + 0.001 + client.increment('tribbles', sample_rate=0.587) + self.assertFalse(mock_socket.sent) + client.random = lambda: effective_sample_rate - 0.001 client.increment('tribbles', sample_rate=0.587) - self.assertEqual(len(mock_socket.sent), 2) + expected = ("tribbles:1|c|@%s" % effective_sample_rate).encode('utf-8') + self.assertEqual([(expected, ('some.host.com', 8125))], + mock_socket.sent) - payload = mock_socket.sent[1][0] - suffix = ("|@%s" % effective_sample_rate).encode('utf-8') - self.assertTrue(payload.endswith(suffix), payload) + +class TestGetStatsdClient(BaseTestStatsdClient, CommonBaseTestsMixIn): + """ + Tests here use get_statsd_client to make a LabeledStatsdClient. + """ + tail_prefix = 'some-name' + expected_prefix_bytes = ('%s.' % tail_prefix).encode() + + def make_test_client(self, conf, tail_prefix='', **kwargs): + return statsd_client.get_statsd_client(conf, tail_prefix, **kwargs) + + +class TestGetLabeledStatsdClient(BaseTestStatsdClient, CommonBaseTestsMixIn): + """ + Tests here use get_labeled_statsd_client to make a LabeledStatsdClient. + """ + tail_prefix = None + expected_prefix_bytes = b'' + + def make_test_client(self, conf, _tail_prefix='', **kwargs): + conf.setdefault('statsd_label_mode', 'dogstatsd') + return statsd_client.get_labeled_statsd_client(conf, **kwargs) class BaseTestStatsdClientOutput(unittest.TestCase): @@ -1001,6 +1027,7 @@ class TestGetLabeledStatsdClientOutput(BaseTestStatsdClientOutput): labeled_statsd.increment, 'the_counter', labels=labels) def _do_test_statsd_methods_no_labels(self, label_mode): + # no default_sample_rate option conf = { 'log_statsd_host': 'localhost', 'log_statsd_port': str(self.port), @@ -1011,6 +1038,10 @@ class TestGetLabeledStatsdClientOutput(BaseTestStatsdClientOutput): labeled_statsd.increment, 'the.counter', labels={}) self.assertStat('the.counter:-1|c', labeled_statsd.decrement, 'the.counter', labels={}) + # but individual call sites could set sample_rate + self.assertStat('the.counter:1|c|@0.9912', + labeled_statsd.increment, 'the.counter', labels={}, + sample_rate=0.9912) self.assertStat( 'the.timing:6280.0|ms', labeled_statsd.timing, 'the.timing', 6.28 * 1000, labels={})

AltStyle によって変換されたページ (->オリジナル) /