Split to mutiple files

parent 9f0eb425
#!/usr/bin/env python
import json
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography import x509
from certapi import app
from certapi.crypto import create_random_sid, create_random_nonce, key_match
from certapi.exceptions import InvalidAuthStateError, InvalidSessionError, InvalidParamError
from certapi import validators
DELAY_GET_SESSION_EXISTS = 10
DELAY_AUTH = 10
DELAY_AUTH_AGAIN = 10
AVAIL_REQUEST_TYPES = {"get_cert", "auth"}
AVAIL_FLAGS = {"renew"}
AVAIL_HASHES = {
hashes.SHA224,
hashes.SHA256,
hashes.SHA384,
hashes.SHA512,
}
SESSION_PARAMS = {
"auth_type",
"nonce",
"digest",
"csr_str",
"flags",
}
AUTH_STATE_PARAMS = {
"status",
}
class InvalidParamError(Exception):
pass
class InvalidSessionError(Exception):
pass
class InvalidAuthStateError(Exception):
pass
def validate_sn_atsha(sn):
if len(sn) != 16:
raise InvalidParamError("SN has invalid length.")
if sn[0:5] != "00000":
raise InvalidParamError("SN has invalid format.")
try:
sn_value = int(sn, 16)
except ValueError:
raise InvalidParamError("SN has invalid format.")
if sn_value % 11 != 0:
raise InvalidParamError("SN has invalid format.")
sn_validators = {
"atsha204": validate_sn_atsha,
}
def validate_csr_common_name(csr, identity):
common_names = csr.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
if len(common_names) != 1:
raise InvalidParamError("CSR has not exactly one CommonName")
common_name = common_names[0].value
if common_name != identity:
raise InvalidParamError("CSR CommonName ({}) does not match desired identity".format(common_name))
def validate_csr_hash(csr):
h = csr.signature_hash_algorithm
if type(h) not in AVAIL_HASHES:
raise InvalidParamError("CSR is signed with not allowed hash ({})".format(h.name))
def validate_csr_signature(csr):
if not csr.is_signature_valid:
raise InvalidParamError("Request signature is not valid")
def csr_from_str(csr_str):
try:
# construct x509 request from PEM string
csr_data = bytes(csr_str, encoding='utf-8')
csr = x509.load_pem_x509_csr(
data=csr_data,
backend=default_backend()
)
except (UnicodeEncodeError, ValueError):
raise InvalidParamError("Invalid CSR format")
return csr
def validate_csr(csr, sn):
csr = csr_from_str(csr)
validate_csr_common_name(csr, sn)
validate_csr_hash(csr)
validate_csr_signature(csr)
def validate_flags(flags):
for flag in flags:
if flag not in AVAIL_FLAGS:
raise InvalidParamError("Flag not available: {}".format(flag))
def validate_req_type(req_type):
if req_type not in AVAIL_REQUEST_TYPES:
raise InvalidParamError("Invalid request type: {}".format(req_type))
def validate_sid(sid):
if sid == "":
return
if (len(sid) != 64 or not sid.islower):
raise InvalidParamError("Bad format of sid : {}".format(sid))
try:
sid = int(sid, 16)
except ValueError:
raise InvalidParamError("Bad format of sid : {}".format(sid))
def validate_digest(digest):
if (len(digest) != 64 or not digest.islower):
raise InvalidParamError("Bad format of digest : {}".format(digest))
try:
digest = int(digest, 16)
except ValueError:
raise InvalidParamError("Bad format of digest : {}".format(digest))
def validate_auth_type(auth_type):
if auth_type not in sn_validators:
raise InvalidParamError("Invalid auth type: {}".format(auth_type))
def create_random_nonce():
return os.urandom(32).hex()
def create_random_sid():
return os.urandom(32).hex()
def get_session_key(sn, sid):
return "session:{}:{}".format(sn, sid)
......@@ -161,20 +22,6 @@ def get_cert_key(sn):
return "certificate:{}".format(sn)
def check_session(session):
for param in SESSION_PARAMS:
if param not in session:
raise InvalidSessionError(param)
def check_auth_state(auth_state):
if len(auth_state) != len(AUTH_STATE_PARAMS):
raise InvalidAuthStateError()
for param in auth_state:
if param not in AUTH_STATE_PARAMS:
raise InvalidAuthStateError()
def create_auth_session(sn, sid, csr_str, flags, auth_type, r):
""" Certificate with matching private key not found in redis
"""
......@@ -198,15 +45,6 @@ def create_auth_session(sn, sid, csr_str, flags, auth_type, r):
}
def key_match(cert_bytes, csr_bytes):
""" Compare public keys of two cryptographic objects and return True if
they are the same, otherwise return False.
"""
cert = x509.load_pem_x509_certificate(cert_bytes, default_backend())
csr = x509.load_pem_x509_csr(csr_bytes, default_backend())
return cert.public_key().public_numbers() == csr.public_key().public_numbers()
def get_auth_state(sn, sid, r):
""" Get state of client authentication from Redis. If the state is failed
or missing, return fail info.
......@@ -230,7 +68,7 @@ def get_auth_state(sn, sid, r):
return (False, {"status": "error"})
try:
check_auth_state(auth_state)
validators.check_auth_state(auth_state)
except InvalidAuthStateError:
app.logger.error("Auth state ivalid for sn=%s, sid=%s", sn, sid)
return (False, {"status": "error"})
......@@ -300,7 +138,7 @@ def get_auth_session(sn, sid, r):
return (None, {"status": "error"})
try:
check_session(session)
validators.check_session(session)
except InvalidSessionError as e:
app.logger.error("Value missing in Redis session (%s) sn=%s, sid=%s", e, sn, sid)
return (None, {"status": "error"})
......@@ -338,7 +176,6 @@ def process_req_auth(sn, sid, digest, auth_type, r):
return status
app.logger.debug("Authentication session found open for sn=%s, sid=%s", sn, sid)
if session["auth_type"] != auth_type:
app.logger.debug("Authentication type does not match, sn=%s, sid=%s", sn, sid)
return {"status": "fail"}
......@@ -362,17 +199,17 @@ def process_request(req_json, r):
app.logger.warning("Request failure: not a valid json")
return {"status": "error"}
try:
validate_req_type(req_json["type"])
validate_auth_type(req_json["auth_type"])
validate_sn = sn_validators[req_json["auth_type"]]
validate_sn(req_json["sn"])
validate_sid(req_json["sid"])
validators.validate_req_type(req_json["type"])
validators.validate_auth_type(req_json["auth_type"])
validators.validate_sn = validators.sn_validators[req_json["auth_type"]]
validators.validate_sn(req_json["sn"])
validators.validate_sid(req_json["sid"])
if req_json["type"] == "get_cert":
validate_csr(req_json["csr"], req_json["sn"])
validate_flags(req_json["flags"])
validators.validate_csr(req_json["csr"], req_json["sn"])
validators.validate_flags(req_json["flags"])
elif req_json["type"] == "auth":
validate_digest(req_json["digest"])
validators.validate_digest(req_json["digest"])
except KeyError as e:
app.logger.warning("Request failure: parameter missing: %s", e)
......
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography import x509
from certapi.exceptions import InvalidParamError
AVAIL_HASHES = {
"sha224",
"sha256",
"sha384",
"sha512",
}
def get_common_names(csr):
return csr.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
def csr_from_str(csr_str):
try:
# construct x509 request from PEM string
csr_data = bytes(csr_str, encoding='utf-8')
csr = x509.load_pem_x509_csr(
data=csr_data,
backend=default_backend()
)
except (UnicodeEncodeError, ValueError):
raise InvalidParamError("Invalid CSR format")
return csr
def create_random_nonce():
return os.urandom(32).hex()
def create_random_sid():
return os.urandom(32).hex()
def key_match(cert_bytes, csr_bytes):
""" Compare public keys of two cryptographic objects and return True if
they are the same, otherwise return False.
"""
cert = x509.load_pem_x509_certificate(cert_bytes, default_backend())
csr = x509.load_pem_x509_csr(csr_bytes, default_backend())
return cert.public_key().public_numbers() == csr.public_key().public_numbers()
class CertAPIError(Exception):
pass
class InvalidParamError(CertAPIError):
pass
class InvalidSessionError(CertAPIError):
pass
class InvalidAuthStateError(CertAPIError):
pass
from certapi.crypto import AVAIL_HASHES, get_common_names, csr_from_str
from certapi.exceptions import InvalidParamError, InvalidAuthStateError, InvalidSessionError
AVAIL_REQUEST_TYPES = {"get_cert", "auth"}
AVAIL_FLAGS = {"renew"}
SESSION_PARAMS = {
"auth_type",
"nonce",
"digest",
"csr_str",
"flags",
}
AUTH_STATE_PARAMS = {
"status",
}
def validate_sn_atsha(sn):
if len(sn) != 16:
raise InvalidParamError("SN has invalid length.")
if sn[0:5] != "00000":
raise InvalidParamError("SN has invalid format.")
try:
sn_value = int(sn, 16)
except ValueError:
raise InvalidParamError("SN has invalid format.")
if sn_value % 11 != 0:
raise InvalidParamError("SN has invalid format.")
sn_validators = {
"atsha204": validate_sn_atsha,
}
def validate_csr_common_name(csr, identity):
common_names = get_common_names(csr)
if len(common_names) != 1:
raise InvalidParamError("CSR has not exactly one CommonName")
common_name = common_names[0].value
if common_name != identity:
raise InvalidParamError("CSR CommonName ({}) does not match desired identity".format(common_name))
def validate_csr_hash(csr):
h = csr.signature_hash_algorithm.name
if h not in AVAIL_HASHES:
raise InvalidParamError("CSR is signed with not allowed hash ({})".format(h))
def validate_csr_signature(csr):
if not csr.is_signature_valid:
raise InvalidParamError("Request signature is not valid")
def validate_csr(csr, sn):
csr = csr_from_str(csr)
validate_csr_common_name(csr, sn)
validate_csr_hash(csr)
validate_csr_signature(csr)
def validate_flags(flags):
for flag in flags:
if flag not in AVAIL_FLAGS:
raise InvalidParamError("Flag not available: {}".format(flag))
def validate_req_type(req_type):
if req_type not in AVAIL_REQUEST_TYPES:
raise InvalidParamError("Invalid request type: {}".format(req_type))
def validate_sid(sid):
if sid == "":
return
if (len(sid) != 64 or not sid.islower()):
raise InvalidParamError("Bad format of sid: {}".format(sid))
try:
sid = int(sid, 16)
except ValueError:
raise InvalidParamError("Bad format of sid: {}".format(sid))
def validate_digest(digest):
if len(digest) != 64:
raise InvalidParamError("Bad format of digest: {}".format(digest))
try:
digest = int(digest, 16)
except ValueError:
raise InvalidParamError("Bad format of digest: {}".format(digest))
def validate_auth_type(auth_type):
if auth_type not in sn_validators:
raise InvalidParamError("Invalid auth type: {}".format(auth_type))
def check_session(session):
for param in SESSION_PARAMS:
if param not in session:
raise InvalidSessionError(param)
def check_auth_state(auth_state):
if len(auth_state) != len(AUTH_STATE_PARAMS):
raise InvalidAuthStateError()
for param in auth_state:
if param not in AUTH_STATE_PARAMS:
raise InvalidAuthStateError()
#!/usr/bin/env python
import json
import redis
......@@ -7,7 +6,7 @@ from flask import jsonify
from flask import g
from certapi import app
from certapi import certificator
from certapi.authentication import process_request
def get_redis():
......@@ -29,6 +28,6 @@ def request_view():
# request.data is class bytes
req_json = request.get_json() # class dict
log_debug_json("Incomming connection", req_json)
reply = certificator.process_request(req_json, get_redis())
reply = process_request(req_json, get_redis())
log_debug_json("Reply", reply)
return jsonify(reply)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment