Merge pull request #3749 from matrix-org/erikj/add_trial_users

Implement trial users
This commit is contained in:
Erik Johnston 2018-08-24 10:10:58 +01:00 committed by GitHub
commit 15e8dd2ccc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 272 additions and 5 deletions

1
changelog.d/3749.feature Normal file
View file

@ -0,0 +1 @@
Add mau_trial_days config param, so that users only get counted as MAU after N days.

View file

@ -797,11 +797,15 @@ class Auth(object):
limit_type=self.hs.config.hs_disabled_limit_type limit_type=self.hs.config.hs_disabled_limit_type
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort # If the user is already part of the MAU cohort or a trial user
if user_id: if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id) timestamp = yield self.store.user_last_seen_monthly_active(user_id)
if timestamp: if timestamp:
return return
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
# Else if there is no room in the MAU bucket, bail # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:

View file

@ -77,10 +77,15 @@ class ServerConfig(Config):
self.max_mau_value = config.get( self.max_mau_value = config.get(
"max_mau_value", 0, "max_mau_value", 0,
) )
self.mau_limits_reserved_threepids = config.get( self.mau_limits_reserved_threepids = config.get(
"mau_limit_reserved_threepids", [] "mau_limit_reserved_threepids", []
) )
self.mau_trial_days = config.get(
"mau_trial_days", 0,
)
# Options to disable HS # Options to disable HS
self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "") self.hs_disabled_message = config.get("hs_disabled_message", "")
@ -365,6 +370,7 @@ class ServerConfig(Config):
# Enables monthly active user checking # Enables monthly active user checking
# limit_usage_by_mau: False # limit_usage_by_mau: False
# max_mau_value: 50 # max_mau_value: 50
# mau_trial_days: 2
# #
# Sometimes the server admin will want to ensure certain accounts are # Sometimes the server admin will want to ensure certain accounts are
# never blocked by mau checking. These accounts are specified here. # never blocked by mau checking. These accounts are specified here.

View file

@ -201,6 +201,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
user_id(str): the user_id to query user_id(str): the user_id to query
""" """
if self.hs.config.limit_usage_by_mau: if self.hs.config.limit_usage_by_mau:
is_trial = yield self.is_trial_user(user_id)
if is_trial:
# we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()

View file

@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationWorkerStore(SQLBaseStore): class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
@cached() @cached()
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self._simple_select_one( return self._simple_select_one(
@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore):
retcols=[ retcols=[
"name", "password_hash", "is_guest", "name", "password_hash", "is_guest",
"consent_version", "consent_server_notice_sent", "consent_version", "consent_server_notice_sent",
"appservice_id", "appservice_id", "creation_ts",
], ],
allow_none=True, allow_none=True,
desc="get_user_by_id", desc="get_user_by_id",
) )
@defer.inlineCallbacks
def is_trial_user(self, user_id):
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
user_id (str)
Returns:
Deferred[bool]
"""
info = yield self.get_user_by_id(user_id)
if not info:
defer.returnValue(False)
now = self.clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
defer.returnValue(is_trial)
@cached() @cached()
def get_user_by_access_token(self, token): def get_user_by_access_token(self, token):
"""Get a user from the given access token. """Get a user from the given access token.

View file

@ -5,7 +5,7 @@ from six import text_type
import attr import attr
from twisted.internet import threads from twisted.internet import address, threads
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
@ -63,7 +63,9 @@ class FakeChannel(object):
self.result["done"] = True self.result["done"] = True
def getPeer(self): def getPeer(self):
return None # We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
def getHost(self): def getHost(self):
return None return None
@ -91,7 +93,7 @@ class FakeSite:
return FakeLogger() return FakeLogger()
def make_request(method, path, content=b""): def make_request(method, path, content=b"", access_token=None):
""" """
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
@ -116,6 +118,11 @@ def make_request(method, path, content=b""):
req = SynapseRequest(site, channel) req = SynapseRequest(site, channel)
req.process = lambda: b"" req.process = lambda: b""
req.content = BytesIO(content) req.content = BytesIO(content)
if access_token:
req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
req.requestReceived(method, path, b"1.1") req.requestReceived(method, path, b"1.1")
return req, channel return req, channel

View file

@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None, "consent_version": None,
"consent_server_notice_sent": None, "consent_server_notice_sent": None,
"appservice_id": None, "appservice_id": None,
"creation_ts": 1000,
}, },
(yield self.store.get_user_by_id(self.user_id)), (yield self.store.get_user_by_id(self.user_id)),
) )

217
tests/test_mau.py Normal file
View file

@ -0,0 +1,217 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests REST events for /rooms paths."""
import json
from mock import Mock, NonCallableMock
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import register, sync
from synapse.util import Clock
from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class TestMauLimit(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
self.clock = Clock(self.reactor)
self.hs = setup_test_homeserver(
self.addCleanup,
"red",
http_client=None,
clock=self.clock,
reactor=self.reactor,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.store = self.hs.get_datastore()
self.hs.config.registrations_require_3pid = []
self.hs.config.enable_registration_captcha = False
self.hs.config.recaptcha_public_key = []
self.hs.config.limit_usage_by_mau = True
self.hs.config.hs_disabled = False
self.hs.config.max_mau_value = 2
self.hs.config.mau_trial_days = 0
self.hs.config.server_notices_mxid = "@server:red"
self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room"
self.resource = JsonResource(self.hs)
register.register_servlets(self.hs, self.resource)
sync.register_servlets(self.hs, self.resource)
def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
self.create_user("kermit3")
e = cm.exception
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
# Advance time by 31 days
self.reactor.advance(31 * 24 * 60 * 60)
self.store.reap_monthly_active_users()
self.reactor.advance(0)
# We should be able to register more users
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
def test_trial_delay(self):
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
# Advance time by 2 days
self.reactor.advance(2 * 24 * 60 * 60)
# Two users should be able to sync
self.do_sync_for_user(token1)
self.do_sync_for_user(token2)
# But the third should fail
with self.assertRaises(SynapseError) as cm:
self.do_sync_for_user(token3)
e = cm.exception
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
# And new registrations are now denied too
with self.assertRaises(SynapseError) as cm:
self.create_user("kermit4")
e = cm.exception
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_trial_users_cant_come_back(self):
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
token3 = self.create_user("kermit3")
self.do_sync_for_user(token3)
# Advance time by 2 days
self.reactor.advance(2 * 24 * 60 * 60)
# Two users should be able to sync
self.do_sync_for_user(token1)
self.do_sync_for_user(token2)
# Advance by 2 months so everyone falls out of MAU
self.reactor.advance(60 * 24 * 60 * 60)
self.store.reap_monthly_active_users()
self.reactor.advance(0)
# We can create as many new users as we want
token4 = self.create_user("kermit4")
self.do_sync_for_user(token4)
token5 = self.create_user("kermit5")
self.do_sync_for_user(token5)
token6 = self.create_user("kermit6")
self.do_sync_for_user(token6)
# users 2 and 3 can come back to bring us back up to MAU limit
self.do_sync_for_user(token2)
self.do_sync_for_user(token3)
# New trial users can still sync
self.do_sync_for_user(token4)
self.do_sync_for_user(token5)
self.do_sync_for_user(token6)
# But old user cant
with self.assertRaises(SynapseError) as cm:
self.do_sync_for_user(token1)
e = cm.exception
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def create_user(self, localpart):
request_data = json.dumps({
"username": localpart,
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
})
request, channel = make_request(b"POST", b"/register", request_data)
render(request, self.resource, self.reactor)
if channel.result["code"] != b"200":
raise HttpResponseException(
int(channel.result["code"]),
channel.result["reason"],
channel.result["body"],
).to_synapse_error()
access_token = channel.json_body["access_token"]
return access_token
def do_sync_for_user(self, token):
request, channel = make_request(b"GET", b"/sync", access_token=token)
render(request, self.resource, self.reactor)
if channel.result["code"] != b"200":
raise HttpResponseException(
int(channel.result["code"]),
channel.result["reason"],
channel.result["body"],
).to_synapse_error()