Refactor of REST server module

parent e88ae803
......@@ -6,12 +6,12 @@ import logging
import sys
import signal
from colorlog import info
from importlib import import_module
from yangson.enumerations import ContentType
from . import usr_op_handlers, usr_state_data_handlers
from .rest_server import RestServer
from .config import CONFIG, load_config, print_config
from .nacm import NacmConfig
from .data import JsonDatastore
from .helpers import DataHelpers
from .handler_list import OP_HANDLERS, STATE_DATA_HANDLES, CONF_DATA_HANDLES
......
......@@ -20,6 +20,7 @@ CONFIG_HTTP = {
"API_ROOT_STAGING": "/restconf_staging",
"SERVER_NAME": "hyper-h2",
"UPLOAD_SIZE_LIMIT": 1,
"LISTEN_LOCALHOST_ONLY": False,
"PORT": 8443,
"SERVER_SSL_CERT": "server.crt",
......
......@@ -21,7 +21,7 @@ from yangson.instance import (
InstanceRoute
)
from .helpers import PathFormat, ErrorHelpers, LogHelpers
from .helpers import PathFormat, ErrorHelpers, LogHelpers, DataHelpers
from .config import CONFIG
epretty = ErrorHelpers.epretty
......@@ -72,7 +72,11 @@ class NoHandlerError(HandlerError):
class NoHandlerForOpError(NoHandlerError):
pass
def __init__(self, op_name: str):
self.op_name = op_name
def __str__(self):
return "Nonexistent handler for operation \"{}\"".format(self.op_name)
class NoHandlerForStateDataError(NoHandlerError):
......@@ -216,8 +220,8 @@ class BaseDatastore:
self._data_lock = Lock()
self._lock_username = None # type: str
self._usr_journals = {} # type: Dict[str, UsrChangeJournal]
self.commit_begin_callback = None # type: Callable
self.commit_end_callback = None # type: Callable
self.commit_begin_callback = None # type: Callable[..., None]
self.commit_end_callback = None # type: Callable[..., None]
if with_nacm:
self.nacm = NacmConfig(self)
......@@ -274,7 +278,7 @@ class BaseDatastore:
try:
sch_pth_list = filter(lambda n: isinstance(n, MemberName), ii)
sch_pth = "".join([str(seg) for seg in sch_pth_list])
sch_pth = DataHelpers.ii2str(sch_pth_list)
sn = self.get_schema_node(sch_pth)
while sn is not None:
......@@ -312,7 +316,7 @@ class BaseDatastore:
n = root.goto(ii)
sch_pth_list = filter(lambda n: isinstance(n, MemberName), ii)
sch_pth = "".join([str(seg) for seg in sch_pth_list])
sch_pth = DataHelpers.ii2str(sch_pth_list)
sn = self.get_schema_node(sch_pth)
state_roots = sn.state_roots()
......@@ -594,7 +598,7 @@ class BaseDatastore:
else:
op_handler = OP_HANDLERS.get_handler(rpc.op_name)
if op_handler is None:
raise NoHandlerForOpError()
raise NoHandlerForOpError(rpc.op_name)
# Print operation input schema
# sn = self.get_schema_node(rpc.path)
......
......@@ -2,7 +2,7 @@ import logging
from colorlog import debug, getLogger
from enum import Enum
from typing import Dict, Any
from typing import Dict, Any, Iterable
from datetime import datetime
from pytz import timezone
from yangson.instance import InstanceRoute, MemberName, EntryKeys, InstanceIdParser, ResourceIdParser
......@@ -10,6 +10,8 @@ from yangson.datamodel import DataModel
from .config import CONFIG_GLOBAL, CONFIG_HTTP
SSLCertT = Dict[str, Any]
class PathFormat(Enum):
URL = 0
......@@ -18,7 +20,7 @@ class PathFormat(Enum):
class CertHelpers:
@staticmethod
def get_field(cert: Dict[str, Any], key: str) -> str:
def get_field(cert: SSLCertT, key: str) -> str:
if CONFIG_HTTP["DBG_DISABLE_CERTS"] and (key == "emailAddress"):
return "test-user"
......@@ -68,6 +70,11 @@ class DataHelpers:
return ii
# Convert InstanceRoute or List[InstanceSelector] to string
@staticmethod
def ii2str(ii: Iterable) -> str:
return "".join([str(seg) for seg in ii])
class DateTimeHelpers:
@staticmethod
......
This diff is collapsed.
......@@ -353,7 +353,7 @@ class UserNacm:
# debug_nacm("checking mii {}".format(mii))
if self.check_data_node_path(root, mii, Permission.NACM_ACCESS_READ) == Action.DENY:
# debug_nacm("Pruning node {} {}".format(id(node.value[child_key]), node.value[child_key]))
debug_nacm("Pruning node {}".format(mii))
debug_nacm("Pruning node {}".format(DataHelpers.ii2str(mii)))
node = node.delete_member(child_key)
else:
node = self._check_data_read_path(m, root, mii).up()
......@@ -370,7 +370,7 @@ class UserNacm:
# debug_nacm("checking eii {}".format(eii))
if self.check_data_node_path(root, eii, Permission.NACM_ACCESS_READ) == Action.DENY:
# debug_nacm("Pruning node {} {}".format(id(node.value[i]), node.value[i]))
debug_nacm("Pruning node {}".format(eii))
debug_nacm("Pruning node {}".format(DataHelpers.ii2str(eii)))
node = node.delete_entry(i)
arr_len -= 1
else:
......@@ -385,15 +385,13 @@ class UserNacm:
else:
return self._check_data_read_path(node, root, ii)
def check_rpc_name(self, rpc_name: str, out_matching_rule: List[NacmRule] = None) -> Action:
def check_rpc_name(self, rpc_name: str) -> Action:
if not self.nacm_enabled:
return Action.PERMIT
for rl in self.rule_lists:
for rpc_rule in filter(lambda r: r.type == NacmRuleType.NACM_RULE_OPERATION, rl.rules):
if rpc_name in rpc_rule.type_data.rpc_names:
if out_matching_rule is not None:
out_matching_rule.append(rpc_rule)
return rpc_rule.action
return self.default_exec
......
......@@ -3,21 +3,23 @@ import ssl
from io import BytesIO
from collections import OrderedDict
from colorlog import error, warning as warn, info
from typing import List, Tuple, Dict, Any, Callable
from typing import List, Tuple, Dict, Any, Callable, Optional
from h2.connection import H2Connection
from h2.errors import PROTOCOL_ERROR, ENHANCE_YOUR_CALM
from h2.events import DataReceived, RequestReceived, RemoteSettingsChanged, StreamEnded
import jetconf.http_handlers as handlers
from . import http_handlers as handlers
from .http_handlers import HttpResponse, HttpStatus
from .config import CONFIG_HTTP, API_ROOT_data, API_ROOT_STAGING_data, API_ROOT_ops
from .data import BaseDatastore
from .helpers import SSLCertT
# Function(method, path) -> bool
HandlerConditionT = Callable[[str, str], bool]
HandlerConditionT = Callable[[str, str], bool] # Function(method, path) -> bool
HttpHandlerT = Callable[[OrderedDict, Optional[str], SSLCertT], handlers.HttpResponse]
h2_handlers = None # type: HandlerList
h2_handlers = None # type: HttpHandlerList
class RequestData:
......@@ -27,18 +29,18 @@ class RequestData:
self.data_overflow = False
class HandlerList:
class HttpHandlerList:
def __init__(self):
self.handlers = [] # type: List[Tuple[HandlerConditionT, Callable]]
self.default_handler = None # type: Callable
self.handlers = [] # type: List[Tuple[HandlerConditionT, HttpHandlerT]]
self.default_handler = None # type: HttpHandlerT
def register_handler(self, condition: HandlerConditionT, handler: Callable):
def register_handler(self, condition: HandlerConditionT, handler: HttpHandlerT):
self.handlers.append((condition, handler))
def register_default_handler(self, handler: Callable):
def register_default_handler(self, handler: HttpHandlerT):
self.default_handler = handler
def get_handler(self, method: str, path: str) -> Callable:
def get_handler(self, method: str, path: str) -> HttpHandlerT:
for h in self.handlers:
if h[0](method, path):
return h[1]
......@@ -51,7 +53,7 @@ class H2Protocol(asyncio.Protocol):
self.conn = H2Connection(client_side=False)
self.transport = None
self.stream_data = {} # type: Dict[int, RequestData]
self.client_cert = None # type: Dict[str, Any]
self.client_cert = None # type: SSLCertT
def connection_made(self, transport: asyncio.Transport):
self.transport = transport
......@@ -59,19 +61,6 @@ class H2Protocol(asyncio.Protocol):
self.transport.write(self.conn.data_to_send())
self.client_cert = self.transport.get_extra_info('peercert')
def send_empty(self, stream_id: int, status_code: str, status_msg: str, status_in_body: bool = True):
response = status_code + " " + status_msg + "\n" if status_in_body else ""
response_bytes = response.encode()
response_headers = (
(":status", status_code),
("content-type", "text/plain"),
("content-length", len(response_bytes)),
("server", CONFIG_HTTP["SERVER_NAME"]),
)
self.conn.send_headers(stream_id, response_headers)
self.conn.send_data(stream_id, response_bytes, end_stream=True)
def data_received(self, data: bytes):
events = self.conn.receive_data(data)
for event in events:
......@@ -98,22 +87,22 @@ class H2Protocol(asyncio.Protocol):
try:
request_data = self.stream_data.pop(event.stream_id)
except KeyError:
self.send_empty(event.stream_id, "400", "Bad Request")
self.send_response(HttpResponse.empty(HttpStatus.BadRequest), event.stream_id)
else:
if request_data.data_overflow:
self.send_empty(event.stream_id, "406", "Not Acceptable")
self.send_response(HttpResponse.empty(HttpStatus.NotAcceptable), event.stream_id)
else:
headers = request_data.headers
body = request_data.data.getvalue().decode('utf-8')
http_method = headers[":method"]
if http_method in ("GET", "DELETE"):
self.handle_get_delete(headers, event.stream_id)
self.run_request_handler(headers, event.stream_id, None)
elif http_method in ("PUT", "POST"):
self.handle_put_post(headers, event.stream_id, body)
body = request_data.data.getvalue().decode('utf-8')
self.run_request_handler(headers, event.stream_id, body)
else:
warn("Unknown http method \"{}\"".format(headers[":method"]))
self.send_empty(event.stream_id, "405", "Method Not Allowed")
self.send_response(HttpResponse.empty(HttpStatus.MethodNotAllowed), event.stream_id)
# elif isinstance(event, RemoteSettingsChanged):
# changed_settings = {}
# for s in event.changed_settings.items():
......@@ -126,31 +115,50 @@ class H2Protocol(asyncio.Protocol):
if dts:
self.transport.write(dts)
def handle_put_post(self, headers: OrderedDict, stream_id: int, data: str):
# Handle PUT, POST
# Find and run handler for specific URI and HTTP method
def run_request_handler(self, headers: OrderedDict, stream_id: int, data: Optional[str]):
url_path = headers[":path"].split("?")[0]
h = h2_handlers.get_handler(headers[":method"], url_path)
if h:
h(self, stream_id, headers, data)
if not h:
self.send_response(HttpResponse.empty(HttpStatus.BadRequest), stream_id)
else:
self.send_empty(stream_id, "400", "Bad Request")
# Run handler and send HTTP response
resp = h(headers, data, self.client_cert)
self.send_response(resp, stream_id)
def send_response(self, resp: HttpResponse, stream_id: int):
resp_headers = (
(':status', resp.status_code),
('content-type', resp.content_type),
('content-length', len(resp.data)),
('server', CONFIG_HTTP["SERVER_NAME"]),
)
def handle_get_delete(self, headers: OrderedDict, stream_id: int):
# Handle GET, DELETE
url_path = headers[":path"].split("?")[0]
if resp.extra_headers:
resp_headers_od = OrderedDict(resp_headers)
resp_headers_od.update(resp.extra_headers)
resp_headers = resp_headers_od.items()
h = h2_handlers.get_handler(headers[":method"], url_path)
if h:
h(self, stream_id, headers)
self.conn.send_headers(stream_id, resp_headers)
# Do this for optimization
if len(resp.data) > self.conn.max_outbound_frame_size:
def split_arr(arr, chunk_size):
for i in range(0, len(arr), chunk_size):
yield arr[i:i + chunk_size]
for data_chunk in split_arr(resp.data, self.conn.max_outbound_frame_size):
self.conn.send_data(stream_id, data_chunk, end_stream=False)
self.conn.send_data(stream_id, bytes(), end_stream=True)
else:
self.send_empty(stream_id, "400", "Bad Request")
self.conn.send_data(stream_id, resp.data, end_stream=True)
class RestServer:
def __init__(self):
# HTTP server init
self.http_handlers = HandlerList()
self.http_handlers = HttpHandlerList()
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.options |= (ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_COMPRESSION)
ssl_context.load_cert_chain(certfile=CONFIG_HTTP["SERVER_SSL_CERT"], keyfile=CONFIG_HTTP["SERVER_SSL_PRIVKEY"])
......@@ -166,7 +174,12 @@ class RestServer:
self.loop = asyncio.get_event_loop()
# Each client connection will create a new H2Protocol instance
listener = self.loop.create_server(H2Protocol, "127.0.0.1", CONFIG_HTTP["PORT"], ssl=ssl_context)
listener = self.loop.create_server(
H2Protocol,
"127.0.0.1" if CONFIG_HTTP["LISTEN_LOCALHOST_ONLY"] else "",
CONFIG_HTTP["PORT"],
ssl=ssl_context
)
self.server = self.loop.run_until_complete(listener)
def register_api_handlers(self, datastore: BaseDatastore):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment