diff --git a/lib/realtime_web/channels/realtime_channel.ex b/lib/realtime_web/channels/realtime_channel.ex index 255b6edfd..5d1455991 100644 --- a/lib/realtime_web/channels/realtime_channel.ex +++ b/lib/realtime_web/channels/realtime_channel.ex @@ -388,6 +388,9 @@ defmodule RealtimeWeb.RealtimeChannel do {:error, :rate_limit_exceeded} -> shutdown_response(socket, "Too many presence messages per second") + {:error, :payload_size_exceeded} -> + shutdown_response(socket, "Track message size exceeded") + {:error, error} -> log_error(socket, "UnableToHandlePresence", error) {:reply, :error, socket} @@ -401,6 +404,9 @@ defmodule RealtimeWeb.RealtimeChannel do {:error, :rate_limit_exceeded} -> shutdown_response(socket, "Too many presence messages per second") + {:error, :payload_size_exceeded} -> + shutdown_response(socket, "Track message size exceeded") + {:error, error} -> log_error(socket, "UnableToHandlePresence", error) {:reply, :error, socket} diff --git a/lib/realtime_web/channels/realtime_channel/presence_handler.ex b/lib/realtime_web/channels/realtime_channel/presence_handler.ex index ec16c7b16..7880605ca 100644 --- a/lib/realtime_web/channels/realtime_channel/presence_handler.ex +++ b/lib/realtime_web/channels/realtime_channel/presence_handler.ex @@ -11,7 +11,7 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandler do alias Phoenix.Tracker.Shard alias Realtime.GenCounter alias Realtime.RateCounter - # alias Realtime.Tenants + alias Realtime.Tenants alias Realtime.Tenants.Authorization alias RealtimeWeb.Presence alias RealtimeWeb.RealtimeChannel.Logging @@ -54,7 +54,12 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandler do @spec handle(map(), pid() | nil, Socket.t()) :: {:ok, Socket.t()} - | {:error, :rls_policy_error | :unable_to_set_policies | :rate_limit_exceeded | :unable_to_track_presence} + | {:error, + :rls_policy_error + | :unable_to_set_policies + | :rate_limit_exceeded + | :unable_to_track_presence + | :payload_size_exceeded} def handle(%{"event" 
=> event} = payload, db_conn, socket) do event = String.downcase(event, :ascii) handle_presence_event(event, payload, db_conn, socket) @@ -105,14 +110,15 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandler do end defp track(socket, payload) do - socket = assign(socket, :presence_enabled?, true) - %{assigns: %{presence_key: presence_key, tenant_topic: tenant_topic}} = socket payload = Map.get(payload, "payload", %{}) - RealtimeWeb.TenantBroadcaster.collect_payload_size(socket.assigns.tenant, payload, :presence) - with :ok <- limit_presence_event(socket), + with tenant <- Tenants.Cache.get_tenant_by_external_id(socket.assigns.tenant), + :ok <- validate_payload_size(tenant, payload), + _ <- RealtimeWeb.TenantBroadcaster.collect_payload_size(socket.assigns.tenant, payload, :presence), + :ok <- limit_presence_event(socket), {:ok, _} <- Presence.track(self(), tenant_topic, presence_key, payload) do + socket = assign(socket, :presence_enabled?, true) {:ok, socket} else {:error, {:already_tracked, pid, _, _}} -> @@ -124,6 +130,9 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandler do {:error, :rate_limit_exceeded} -> {:error, :rate_limit_exceeded} + {:error, :payload_size_exceeded} -> + {:error, :payload_size_exceeded} + {:error, error} -> log_error("UnableToTrackPresence", error) {:error, :unable_to_track_presence} @@ -144,8 +153,6 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandler do %{assigns: %{presence_rate_counter: presence_counter, tenant: _tenant_id}} = socket {:ok, rate_counter} = RateCounter.get(presence_counter) - # tenant = Tenants.Cache.get_tenant_by_external_id(tenant_id) - if rate_counter.avg > @presence_limit do {:error, :rate_limit_exceeded} else @@ -153,4 +160,14 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandler do :ok end end + + # Added due to the fact that JSON decoding adds some overhead and erlang term will be slightly larger + @payload_size_padding 500 + defp validate_payload_size(tenant, payload) do + if 
:erlang.external_size(payload) > tenant.max_payload_size_in_kb * 1000 + @payload_size_padding do + {:error, :payload_size_exceeded} + else + :ok + end + end end diff --git a/mix.exs b/mix.exs index 9e4ea1736..7c9402619 100644 --- a/mix.exs +++ b/mix.exs @@ -4,7 +4,7 @@ defmodule Realtime.MixProject do def project do [ app: :realtime, - version: "2.54.2", + version: "2.54.3", elixir: "~> 1.17.3", elixirc_paths: elixirc_paths(Mix.env()), start_permanent: Mix.env() == :prod, diff --git a/test/realtime/database_distributed_test.exs b/test/realtime/database_distributed_test.exs new file mode 100644 index 000000000..43b40743e --- /dev/null +++ b/test/realtime/database_distributed_test.exs @@ -0,0 +1,100 @@ +defmodule Realtime.DatabaseDistributedTest do + # async: false due to usage of Clustered + dev_tenant + use Realtime.DataCase, async: false + + import ExUnit.CaptureLog + + alias Realtime.Database + alias Realtime.Rpc + alias Realtime.Tenants.Connect + + doctest Realtime.Database + def handle_telemetry(event, metadata, content, pid: pid), do: send(pid, {event, metadata, content}) + + setup do + tenant = Containers.checkout_tenant() + :telemetry.attach(__MODULE__, [:realtime, :database, :transaction], &__MODULE__.handle_telemetry/4, pid: self()) + + on_exit(fn -> :telemetry.detach(__MODULE__) end) + + %{tenant: tenant} + end + + @aux_mod (quote do + defmodule DatabaseAux do + def checker(transaction_conn) do + Postgrex.query!(transaction_conn, "SELECT 1", []) + end + + def error(transaction_conn) do + Postgrex.query!(transaction_conn, "SELECT 1/0", []) + end + + def exception(_) do + raise RuntimeError, "💣" + end + end + end) + + Code.eval_quoted(@aux_mod) + + describe "transaction/1 in clustered mode" do + setup do + Connect.shutdown("dev_tenant") + # Waiting for :syn to "unregister" if the Connect process was up + Process.sleep(100) + :ok + end + + test "success call returns output" do + {:ok, node} = Clustered.start(@aux_mod) + {:ok, db_conn} = Rpc.call(node, 
Connect, :connect, ["dev_tenant", "us-east-1"]) + assert node(db_conn) == node + assert {:ok, %Postgrex.Result{rows: [[1]]}} = Database.transaction(db_conn, &DatabaseAux.checker/1) + end + + test "handles database errors" do + metadata = [external_id: "123", project: "123"] + {:ok, node} = Clustered.start(@aux_mod) + {:ok, db_conn} = Rpc.call(node, Connect, :connect, ["dev_tenant", "us-east-1"]) + assert node(db_conn) == node + + assert capture_log(fn -> + assert {:error, %Postgrex.Error{}} = Database.transaction(db_conn, &DatabaseAux.error/1, [], metadata) + # We have to wait for logs to be relayed to this node + Process.sleep(100) + end) =~ "project=123 external_id=123 [error] ErrorExecutingTransaction:" + end + + test "handles exception" do + metadata = [external_id: "123", project: "123"] + {:ok, node} = Clustered.start(@aux_mod) + {:ok, db_conn} = Rpc.call(node, Connect, :connect, ["dev_tenant", "us-east-1"]) + assert node(db_conn) == node + + assert capture_log(fn -> + assert {:error, %RuntimeError{}} = Database.transaction(db_conn, &DatabaseAux.exception/1, [], metadata) + # We have to wait for logs to be relayed to this node + Process.sleep(100) + end) =~ "project=123 external_id=123 [error] ErrorExecutingTransaction:" + end + + test "db process is not alive anymore" do + metadata = [external_id: "123", project: "123", tenant_id: "123"] + {:ok, node} = Clustered.start(@aux_mod) + # Grab a remote pid that will not exist. :erpc uses a new process to perform the call. 
+ # Once it has returned the process is not alive anymore + + pid = Rpc.call(node, :erlang, :self, []) + assert node(pid) == node + + assert capture_log(fn -> + assert {:error, {:exit, {:noproc, {DBConnection.Holder, :checkout, [^pid, []]}}}} = + Database.transaction(pid, &DatabaseAux.checker/1, [], metadata) + + # We have to wait for logs to be relayed to this node + Process.sleep(100) + end) =~ "project=123 external_id=123 [error] ErrorExecutingTransaction:" + end + end +end diff --git a/test/realtime/database_test.exs b/test/realtime/database_test.exs index f48de14b6..df4e63456 100644 --- a/test/realtime/database_test.exs +++ b/test/realtime/database_test.exs @@ -1,12 +1,9 @@ defmodule Realtime.DatabaseTest do - # async: false due to usage of Clustered - use Realtime.DataCase, async: false + use Realtime.DataCase, async: true import ExUnit.CaptureLog alias Realtime.Database - alias Realtime.Rpc - alias Realtime.Tenants.Connect doctest Realtime.Database def handle_telemetry(event, metadata, content, pid: pid), do: send(pid, {event, metadata, content}) @@ -215,84 +212,6 @@ defmodule Realtime.DatabaseTest do end end - @aux_mod (quote do - defmodule DatabaseAux do - def checker(transaction_conn) do - Postgrex.query!(transaction_conn, "SELECT 1", []) - end - - def error(transaction_conn) do - Postgrex.query!(transaction_conn, "SELECT 1/0", []) - end - - def exception(_) do - raise RuntimeError, "💣" - end - end - end) - - Code.eval_quoted(@aux_mod) - - describe "transaction/1 in clustered mode" do - setup do - Connect.shutdown("dev_tenant") - # Waiting for :syn to "unregister" if the Connect process was up - Process.sleep(100) - :ok - end - - test "success call returns output" do - {:ok, node} = Clustered.start(@aux_mod) - {:ok, db_conn} = Rpc.call(node, Connect, :connect, ["dev_tenant", "us-east-1"]) - assert node(db_conn) == node - assert {:ok, %Postgrex.Result{rows: [[1]]}} = Database.transaction(db_conn, &DatabaseAux.checker/1) - end - - test "handles database 
errors" do - metadata = [external_id: "123", project: "123"] - {:ok, node} = Clustered.start(@aux_mod) - {:ok, db_conn} = Rpc.call(node, Connect, :connect, ["dev_tenant", "us-east-1"]) - assert node(db_conn) == node - - assert capture_log(fn -> - assert {:error, %Postgrex.Error{}} = Database.transaction(db_conn, &DatabaseAux.error/1, [], metadata) - # We have to wait for logs to be relayed to this node - Process.sleep(100) - end) =~ "project=123 external_id=123 [error] ErrorExecutingTransaction:" - end - - test "handles exception" do - metadata = [external_id: "123", project: "123"] - {:ok, node} = Clustered.start(@aux_mod) - {:ok, db_conn} = Rpc.call(node, Connect, :connect, ["dev_tenant", "us-east-1"]) - assert node(db_conn) == node - - assert capture_log(fn -> - assert {:error, %RuntimeError{}} = Database.transaction(db_conn, &DatabaseAux.exception/1, [], metadata) - # We have to wait for logs to be relayed to this node - Process.sleep(100) - end) =~ "project=123 external_id=123 [error] ErrorExecutingTransaction:" - end - - test "db process is not alive anymore" do - metadata = [external_id: "123", project: "123", tenant_id: "123"] - {:ok, node} = Clustered.start(@aux_mod) - # Grab a remote pid that will not exist. :erpc uses a new process to perform the call. 
- # Once it has returned the process is not alive anymore - - pid = Rpc.call(node, :erlang, :self, []) - assert node(pid) == node - - assert capture_log(fn -> - assert {:error, {:exit, {:noproc, {DBConnection.Holder, :checkout, [^pid, []]}}}} = - Database.transaction(pid, &DatabaseAux.checker/1, [], metadata) - - # We have to wait for logs to be relayed to this node - Process.sleep(100) - end) =~ "project=123 external_id=123 [error] ErrorExecutingTransaction:" - end - end - describe "pool_size_by_application_name/2" do test "returns the number of connections per application name" do assert Database.pool_size_by_application_name("realtime_connect", %{}) == 1 diff --git a/test/realtime/messages_test.exs b/test/realtime/messages_test.exs index cca0ce742..9b99a5580 100644 --- a/test/realtime/messages_test.exs +++ b/test/realtime/messages_test.exs @@ -1,5 +1,6 @@ defmodule Realtime.MessagesTest do - use Realtime.DataCase, async: true + # usage of Clustered + use Realtime.DataCase, async: false alias Realtime.Api.Message alias Realtime.Database diff --git a/test/realtime_web/channels/realtime_channel/presence_handler_test.exs b/test/realtime_web/channels/realtime_channel/presence_handler_test.exs index 219f13e55..69f89f36a 100644 --- a/test/realtime_web/channels/realtime_channel/presence_handler_test.exs +++ b/test/realtime_web/channels/realtime_channel/presence_handler_test.exs @@ -403,6 +403,17 @@ defmodule RealtimeWeb.RealtimeChannel.PresenceHandlerTest do assert log =~ "PresenceRateLimitReached" end + + test "fails on high payload size", %{tenant: tenant, topic: topic, db_conn: db_conn} do + key = random_string() + socket = socket_fixture(tenant, topic, key, private?: false) + payload_size = tenant.max_payload_size_in_kb * 1000 + + payload = %{content: random_string(payload_size)} + + assert {:error, :payload_size_exceeded} = + PresenceHandler.handle(%{"event" => "track", "payload" => payload}, db_conn, socket) + end end describe "sync/1" do @@ -461,7 +472,6 @@ 
defmodule RealtimeWeb.RealtimeChannel.PresenceHandlerTest do assert log =~ "PresenceRateLimitReached" end - @tag :skip @tag policies: [:authenticated_read_broadcast_and_presence, :authenticated_write_broadcast_and_presence] test "respects rate limits on private channels", %{tenant: tenant, topic: topic, db_conn: db_conn} do key = random_string() diff --git a/test/realtime_web/channels/realtime_channel_test.exs b/test/realtime_web/channels/realtime_channel_test.exs index 055516e64..16e337af8 100644 --- a/test/realtime_web/channels/realtime_channel_test.exs +++ b/test/realtime_web/channels/realtime_channel_test.exs @@ -1,6 +1,5 @@ defmodule RealtimeWeb.RealtimeChannelTest do - # Can't run async true because under the hood Cachex is used and it doesn't see Ecto Sandbox - use RealtimeWeb.ChannelCase, async: false + use RealtimeWeb.ChannelCase, async: true use Mimic import ExUnit.CaptureLog @@ -23,6 +22,7 @@ defmodule RealtimeWeb.RealtimeChannelTest do setup do tenant = Containers.checkout_tenant(run_migrations: true) + Cachex.put!(Realtime.Tenants.Cache, {{:get_tenant_by_external_id, 1}, [tenant.external_id]}, {:cached, tenant}) {:ok, tenant: tenant} end @@ -273,6 +273,35 @@ defmodule RealtimeWeb.RealtimeChannelTest do # presence_state assert Enum.sum(bucket) == 1 end + + test "presence track closes on high payload size", %{tenant: tenant} do + topic = "realtime:test" + jwt = Generators.generate_jwt_token(tenant) + {:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt)) + + assert {:ok, _, %Socket{} = socket} = subscribe_and_join(socket, topic, %{}) + + assert_receive %Phoenix.Socket.Message{topic: "realtime:test", event: "presence_state"}, 500 + + payload = %{ + type: "presence", + event: "TRACK", + payload: %{name: "realtime_presence_96", t: 1814.7000000029802, content: String.duplicate("a", 3_500_000)} + } + + push(socket, "presence", payload) + + assert_receive %Phoenix.Socket.Message{ + event: "system", + payload: %{ + 
extension: "system", + message: "Track message size exceeded", + status: "error" + }, + topic: ^topic + }, + 500 + end end describe "unexpected errors" do @@ -978,7 +1007,10 @@ defmodule RealtimeWeb.RealtimeChannelTest do put_in(extension, ["settings", "db_port"], db_port) ] - Realtime.Api.update_tenant(tenant, %{extensions: extensions}) + with {:ok, tenant} <- Realtime.Api.update_tenant(tenant, %{extensions: extensions}) do + Cachex.put!(Realtime.Tenants.Cache, {{:get_tenant_by_external_id, 1}, [tenant.external_id]}, {:cached, tenant}) + {:ok, tenant} + end end defp assert_process_down(pid) do diff --git a/test/test_helper.exs b/test/test_helper.exs index c97eaa0b2..002e01b13 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -2,7 +2,7 @@ start_time = :os.system_time(:millisecond) alias Realtime.Api alias Realtime.Database -ExUnit.start(exclude: [:failing], max_cases: 3, capture_log: true) +ExUnit.start(exclude: [:failing], max_cases: 4, capture_log: true) max_cases = ExUnit.configuration()[:max_cases]