diff --git a/api.py b/api.py index c30edfd..391fdf3 100644 --- a/api.py +++ b/api.py @@ -16,9 +16,6 @@ import mkey -class BadInputError(ValueError): - pass - app = FastAPI() app.mount("/static", StaticFiles(directory="static"), name="static") @@ -26,10 +23,12 @@ async def get_mkey(platform: Optional[str] = None, month: Optional[int] = None, generator = mkey.mkey_generator(debug=False) if platform not in ["RVL", "TWL", "CTR", "WUP", "HAC"]: - raise BadInputError(f"{platform} is an invalid platform.") + raise mkey.InvalidInputError(f"{platform} is an invalid platform.") master_key = None try: master_key = generator.generate(inquiry=inquiry, month=month, day=day, device=platform, aux=aux) + except mkey.InvalidInputError as e: + raise mkey.InvalidInputError(str(e)) except ValueError as e: raise ValueError(str(e)) return str(master_key) @@ -39,7 +38,7 @@ async def get_mkey(platform: Optional[str] = None, month: Optional[int] = None, async def api(platform: Optional[str] = None, month: Optional[int] = None, day: Optional[int] = None, inquiry: Optional[str] = None, aux: Optional[str] = None): try: ret = await get_mkey(platform, month, day, inquiry, aux) - except BadInputError as e: + except mkey.InvalidInputError as e: raise HTTPException(status_code=400, detail=str(e)) except ValueError as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/mkey.py b/mkey.py index ed847a0..be9d74c 100644 --- a/mkey.py +++ b/mkey.py @@ -31,6 +31,11 @@ try: import hexdump except ImportError: pass + +class InvalidInputError(ValueError): + pass + + class mkey_generator(): __props = { "RVL": { @@ -172,7 +177,7 @@ def _detect_algorithm(self, device, inquiry): if "v0" in algorithms: return "v0" else: - raise ValueError("v0 algorithm not supported by %s." % device) + raise InvalidInputError("v0 algorithm not supported by %s." % device) elif len(inquiry) == 10: version = int((int(inquiry) / 10000000) % 100) @@ -183,14 +188,14 @@ def _detect_algorithm(self, device, inquiry): elif "v3" in algorithms: return "v3" else: - raise ValueError("v1/v2/v3 algorithms not supported by %s." % device) + raise InvalidInputError("v1/v2/v3 algorithms not supported by %s." % device) elif len(inquiry) == 6: if "v4" in algorithms: return "v4" else: - raise ValueError("v4 algorithm not supported by %s." % device) + raise InvalidInputError("v4 algorithm not supported by %s." % device) else: - raise ValueError("Inquiry number must be 6, 8 or 10 digits.") + raise InvalidInputError("Inquiry number must be 6, 8 or 10 digits.") # CRC-32 implementation (v0). def _calculate_crc(self, poly, xorout, addout, inbuf): @@ -260,7 +265,7 @@ def _generate_v1_v2(self, props, inquiry, month, day): # be a guaranteed set of regions available. # if region not in props["regions"]: - raise ValueError("%s is an invalid region for console %s." % + raise InvalidInputError("%s is an invalid region for console %s." % (region, props["device"])) # @@ -373,10 +378,10 @@ def _generate_v3_v4(self, props, inquiry, aux = None): raise ValueError("v3/v4 attempted, but data directory doesn't exist or was not specified.") if algorithm == "v4" and not aux: - raise ValueError("v4 attempted, but no auxiliary string (device ID required).") + raise InvalidInputError("v4 attempted, but no auxiliary string (device ID required).") if algorithm == "v4" and len(aux) != 16: - raise ValueError("v4 attempted, but auxiliary string (device ID) of invalid length.") + raise InvalidInputError("v4 attempted, but auxiliary string (device ID) of invalid length.") if algorithm == "v4": version = int((inquiry / 10000) % 100) @@ -440,7 +445,7 @@ def _generate_v3_v4(self, props, inquiry, aux = None): def generate(self, inquiry, month = None, day = None, aux = None, device = None): inquiry = inquiry.replace(" ", "") if not inquiry.isdigit(): - raise ValueError("Inquiry string must represent a decimal number.") + raise InvalidInputError("Inquiry string must represent a decimal number.") if month is None: month = datetime.date.today().month @@ -448,14 +453,14 @@ def generate(self, inquiry, month = None, day = None, aux = None, device = None) day = datetime.date.today().day if month < 1 or month > 12: - raise ValueError("Month must be between 1 and 12.") + raise InvalidInputError("Month must be between 1 and 12.") if day < 1 or day > 31: - raise ValueError("Day must be between 1 and 31.") + raise InvalidInputError("Day must be between 1 and 31.") if not device: device = self.default_device if device not in self.devices: - raise ValueError("Unsupported device: %s." % device) + raise InvalidInputError("Unsupported device: %s." % device) # We can glean information about the required algorithm from the inquiry number. algorithm = self._detect_algorithm(device, inquiry)