Skip to content

docs: A websocket API example. #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions examples/complete-websocket/connection_handler.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module "lambda_function" {
source = "terraform-aws-modules/lambda/aws"

function_name = "aws-ws-connection-handler"
description = "AWS WS connection handler"
handler = "handler.lambda_handler"
runtime = "python3.8"

publish = true

source_path = "./handler"

allowed_triggers = {
AllowExecutionFromAPIGateway = {
service = "apigateway"
source_arn = "${module.api_gateway.this_apigatewayv2_api_execution_arn}/*/*"
}
}

attach_policy_statements = true
policy_statements = {
manage_connections = {
effect = "Allow",
actions = ["execute-api:ManageConnections"],
resources = ["${module.api_gateway.default_apigatewayv2_stage_execution_arn}/*"]
}
dynamodb = {
effect = "Allow",
actions = ["dynamodb:GetItem", "dynamodb:PutItem", "dynamodb:DeleteItem"],
resources = [module.dynamodb_table.this_dynamodb_table_arn]
}

}

}
18 changes: 18 additions & 0 deletions examples/complete-websocket/dynamodb.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module "dynamodb_table" {
source = "terraform-aws-modules/dynamodb-table/aws"

name = "aws-ws-connections"
hash_key = "connection_id"

attributes = [
{
name = "connection_id"
type = "S"
}
]

tags = {
Terraform = "true"
}
}

8 changes: 8 additions & 0 deletions examples/complete-websocket/events.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#### Event bus

resource "aws_cloudwatch_event_rule" "heartbeat" {
name = "aws-ws-heartbeart"
description = "Ping connected Websocket clients"
schedule_expression = "rate(1 minute)"
}

190 changes: 190 additions & 0 deletions examples/complete-websocket/handler/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Modified from AWS example code: https://docs.aws.amazon.com/code-samples/latest/catalog/python-cross_service-apigateway_websocket_chat-lambda_chat.py.html
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Purpose

Shows how to implement an AWS Lambda function as part of a websocket chat application.
The function handles messages from an Amazon API Gateway websocket API and uses an
Amazon DynamoDB table to track active connections. When a message is sent by any
participant, it is posted to all other active connections by using the Amazon
API Gateway Management API.

Logs written by this handler can be found in Amazon CloudWatch.
"""

import json
import logging
import os
import boto3
from botocore.exceptions import ClientError

logger = logging.getLogger()
logger.setLevel(logging.INFO)


def handle_connect(table, connection_id):
"""
Handles new connections by adding the connection ID and user name to the
DynamoDB table.

:param user_name: The name of the user that started the connection.
:param table: The DynamoDB connection table.
:param connection_id: The websocket connection ID of the new connection.
:return: An HTTP status code that indicates the result of adding the connection
to the DynamoDB table.
"""
status_code = 200
try:
table.put_item(Item={"connection_id": connection_id})
logger.info("Added connection %s", connection_id)
except ClientError:
logger.exception("Couldn't add connection %s", connection_id)
status_code = 503
return status_code


def handle_disconnect(table, connection_id):
"""
Handles disconnections by removing the connection record from the DynamoDB table.

:param table: The DynamoDB connection table.
:param connection_id: The websocket connection ID of the connection to remove.
:return: An HTTP status code that indicates the result of removing the connection
from the DynamoDB table.
"""
status_code = 200
try:
table.delete_item(Key={"connection_id": connection_id})
logger.info("Disconnected connection %s.", connection_id)
except ClientError:
logger.exception("Couldn't disconnect connection %s.", connection_id)
status_code = 503
return status_code


def handle_message(table, connection_id, event_body, apig_management_client):
"""
Handles messages sent by a participant in the chat. Looks up all connections
currently tracked in the DynamoDB table, and uses the API Gateway Management API
to post the message to each other connection.

When posting to a connection results in a GoneException, the connection is
considered disconnected and is removed from the table. This is necessary
because disconnect messages are not always sent when a client disconnects.

:param table: The DynamoDB connection table.
:param connection_id: The ID of the connection that sent the message.
:param event_body: The body of the message sent from API Gateway. This is a
dict with a `msg` field that contains the message to send.
:param apig_management_client: A Boto3 API Gateway Management API client.
:return: An HTTP status code that indicates the result of posting the message
to all active connections.
"""
status_code = 200
user_name = "guest"
try:
item_response = table.get_item(Key={"connection_id": connection_id})
user_name = item_response["Item"]["user_name"]
logger.info("Got user name %s.", user_name)
except ClientError:
logger.exception("Couldn't find user name. Using %s.", user_name)

connection_ids = []
try:
scan_response = table.scan(ProjectionExpression="connection_id")
connection_ids = [item["connection_id"] for item in scan_response["Items"]]
logger.info("Found %s active connections.", len(connection_ids))
except ClientError:
logger.exception("Couldn't get connections.")
status_code = 404

message = f"{user_name}: {event_body['msg']}".encode("utf-8")
logger.info("Message: %s", message)

for other_conn_id in connection_ids:
try:
if other_conn_id != connection_id:
send_response = apig_management_client.post_to_connection(
Data=message, ConnectionId=other_conn_id
)
logger.info(
"Posted message to connection %s, got response %s.",
other_conn_id,
send_response,
)
except ClientError:
logger.exception("Couldn't post to connection %s.", other_conn_id)
except apig_management_client.exceptions.GoneException:
logger.info("Connection %s is gone, removing.", other_conn_id)
try:
table.delete_item(Key={"connection_id": other_conn_id})
except ClientError:
logger.exception("Couldn't remove connection %s.", other_conn_id)

return status_code


def lambda_handler(event, context):
"""
An AWS Lambda handler that receives events from an API Gateway websocket API
and dispatches them to various handler functions.

This function looks up the name of a DynamoDB table in the `table_name` environment
variable. The table must have a primary key named `connection_id`.

This function handles three routes: $connect, $disconnect, and sendmessage. Any
other route results in a 404 status code.

The $connect route accepts a query string `name` parameter that is the name of
the user that originated the connection. This name is added to all chat messages
sent by that user.

:param event: A dict that contains request data, query string parameters, and
other data sent by API Gateway.
:param context: Context around the request.
:return: A response dict that contains an HTTP status code that indicates the
result of handling the event.
"""

table_name = "aws-ws-connections"
route_key = event.get("requestContext", {}).get("routeKey")
logger.info(route_key)
connection_id = event.get("requestContext", {}).get("connectionId")
if table_name is None or route_key is None or connection_id is None:
return {"statusCode": 400}

table = boto3.resource("dynamodb").Table(table_name)
logger.info("Request: %s, use table %s.", route_key, table.name)

response = {"statusCode": 200}
if route_key == "$connect":
response["statusCode"] = handle_connect(table, connection_id)
elif route_key == "$disconnect":
response["statusCode"] = handle_disconnect(table, connection_id)
elif route_key == "$default":
body = event.get("body", "NO MESSAGE")
domain = event.get("requestContext", {}).get("domainName")
stage = event.get("requestContext", {}).get("stage")
if domain is None or stage is None:
logger.warning(
"Couldn't send message. Bad endpoint in request: domain '%s', "
"stage '%s'",
domain,
stage,
)
response["statusCode"] = 400
else:
apig_management_client = boto3.client(
"apigatewaymanagementapi", endpoint_url=f"https://{domain}/{stage}"
)
apig_management_client.post_to_connection(
Data=body + "right back at yah",
ConnectionId=connection_id,
)

else:
response["statusCode"] = 404

return response
45 changes: 45 additions & 0 deletions examples/complete-websocket/heartbeat.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module "heartbeat_function" {
source = "terraform-aws-modules/lambda/aws"

function_name = "aws-ws-heartbeat"
description = "AWS WS Test Heartbeart"
handler = "handler.lambda_handler"
runtime = "python3.8"

publish = true

source_path = "./heartbeat"

allowed_triggers = {
RunHeartbeat = {
principal = "events.amazonaws.com"
source_arn = aws_cloudwatch_event_rule.heartbeat.arn
}
}

attach_policy_statements = true
policy_statements = {
manage_connections = {
effect = "Allow",
actions = ["execute-api:ManageConnections"],
resources = ["${module.api_gateway.default_apigatewayv2_stage_execution_arn}/*"]
}
dynamodb = {
effect = "Allow",
actions = ["dynamodb:GetItem", "dynamodb:Scan"],
resources = [module.dynamodb_table.this_dynamodb_table_arn]
}
}
}


resource "aws_cloudwatch_event_target" "send_heartbeat" {
arn = module.heartbeat_function.this_lambda_function_arn
rule = aws_cloudwatch_event_rule.heartbeat.name
}

resource "local_file" "api_domain" {
filename = "heartbeat/api_url"
file_permission = "0666"
content = "${replace(module.api_gateway.default_apigatewayv2_stage_invoke_url, "wss", "https")}/"
}
56 changes: 56 additions & 0 deletions examples/complete-websocket/heartbeat/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from datetime import datetime, timezone
from pathlib import Path

import boto3
from botocore.exceptions import ClientError

logger = logging.getLogger()
logger.setLevel(logging.INFO)


API_URL = Path("./api_url").read_text()


def lambda_handler(event, _):
table_name = "aws-ws-connections"

connection_id = event.get("requestContext", {}).get("connectionId")

table = boto3.resource("dynamodb").Table(table_name)

connection_ids = []

try:
scan_response = table.scan(ProjectionExpression="connection_id")
connection_ids = [item["connection_id"] for item in scan_response["Items"]]
logger.info("Found %s active connections.", len(connection_ids))
except ClientError:
logger.exception("Couldn't get connections.")

message = f"PING? {datetime.now(tz=timezone.utc)}"
logger.info("Message: %s", message)

apig_management_client = boto3.client(
"apigatewaymanagementapi", endpoint_url=API_URL
)

for other_conn_id in connection_ids:
try:
if other_conn_id != connection_id:
send_response = apig_management_client.post_to_connection(
Data=message, ConnectionId=other_conn_id
)
logger.info(
"Posted message to connection %s, got response %s.",
other_conn_id,
send_response,
)
except ClientError:
logger.exception("Couldn't post to connection %s.", other_conn_id)
except apig_management_client.exceptions.GoneException:
logger.info("Connection %s is gone, removing.", other_conn_id)
try:
table.delete_item(Key={"connection_id": other_conn_id})
except ClientError:
logger.exception("Couldn't remove connection %s.", other_conn_id)
Loading