diff --git a/autosubmit_api/app.py b/autosubmit_api/app.py index eb747403fd5585285e0f0059fa6571e069bffd8e..5e2f86b3d5b0189f65bf51d3f40a07b2210fd8de 100644 --- a/autosubmit_api/app.py +++ b/autosubmit_api/app.py @@ -1,58 +1,53 @@ import os import sys -import requests -from flask_cors import CORS -from flask import Flask -from autosubmit_api.bgtasks.scheduler import create_bind_scheduler -from autosubmit_api.blueprints.v3 import create_v3_blueprint -from autosubmit_api.blueprints.v4 import create_v4_blueprint +import time +from fastapi.responses import JSONResponse +from autosubmit_api import routers +from autosubmit_api.bgtasks.scheduler import create_scheduler from autosubmit_api.database import prepare_db from autosubmit_api.experiment import common_requests as CommonRequests -from autosubmit_api.logger import get_app_logger +from autosubmit_api.logger import get_app_logger, logger from autosubmit_api.config.basicConfig import APIBasicConfig from autosubmit_api.config import ( PROTECTION_LEVEL, CAS_LOGIN_URL, CAS_VERIFY_URL, + AS_API_ROOT_PATH, get_run_background_tasks_on_start, get_disable_background_tasks, ) -from autosubmit_api.views import handle_HTTP_exception, home -from werkzeug.exceptions import HTTPException +from fastapi import FastAPI, HTTPException as FastAPIHTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager +from autosubmit_api import __version__ as APIVersion +from autosubmit_api.middleware import HttpUrlModifyMiddleware +sys.path.insert(0, os.path.abspath(".")) -def create_app(): - """ - Autosubmit Flask application factory - This function initializes the application properly - """ - - sys.path.insert(0, os.path.abspath(".")) - - app = Flask(__name__) +scheduler = create_scheduler() - # CORS setup - CORS(app) - # Logger binding - app.logger = get_app_logger() - app.logger.info("PYTHON VERSION: " + sys.version) +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Beware 
this lifespan will run on every worker + """ + # Startup + yield + # Shutdown - # Enforce Language Locale - CommonRequests.enforceLocal(app.logger) - requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS += "HIGH:!DH:!aNULL" - try: - requests.packages.urllib3.contrib.pyopenssl.DEFAULT_SSL_CIPHER_LIST += ( - "HIGH:!DH:!aNULL" - ) - except AttributeError: - app.logger.warning("No pyopenssl support used / needed / available") +def create_app(): + """ + Create the FastAPI app. It will run this only once before the server starts when using multiple workers. + """ + logger.info("PYTHON VERSION: " + sys.version) + CommonRequests.enforceLocal(logger) # Initial read config APIBasicConfig.read() - app.logger.debug("API Basic config: " + str(APIBasicConfig().props())) - app.logger.debug( + logger.debug("API Basic config: " + str(APIBasicConfig().props())) + logger.debug( "Env Config: " + str( { @@ -61,6 +56,7 @@ def create_app(): "CAS_VERIFY_URL": CAS_VERIFY_URL, "DISABLE_BACKGROUND_TASKS": get_disable_background_tasks(), "RUN_BACKGROUND_TASKS_ON_START": get_run_background_tasks_on_start(), + "AS_API_ROOT_PATH": AS_API_ROOT_PATH, } ) ) @@ -68,22 +64,90 @@ def create_app(): # Prepare DB prepare_db() - # Background Scheduler - create_bind_scheduler(app) + # Initial background tasks + scheduler.start() + + return FastAPI( + root_path=AS_API_ROOT_PATH, + lifespan=lifespan, + redirect_slashes=True, + title="Autosubmit API", + version=APIVersion, + license_info={ + "name": "GNU General Public License", + "url": "https://www.gnu.org/licenses/gpl-3.0.html", + }, + ) + + +app = create_app() + + +# Exception handlers - ################################ ROUTES ################################ - app.route("/")(home) +@app.exception_handler(FastAPIHTTPException) +async def http_exception_handler(request: Request, exc: FastAPIHTTPException): + return JSONResponse( + content={"error": True, "error_message": exc.detail}, + status_code=exc.status_code, + ) + + +@app.exception_handler(Exception) 
+async def global_exception_handler(request: Request, exc: Exception): + return JSONResponse( + content={"error": True, "error_message": "An unexpected error occurred."}, + status_code=500, + ) + + +# Middlewares + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.middleware("http") +async def log_runtime(request: Request, call_next): + logger = get_app_logger() + start_time = time.time() + try: + path = request.url.path + if request.url.query: + path += "?" + request.url.query + method = request.method + except Exception: + path = "" + method = "" + logger.info("\033[94m{} {}|RECEIVED\033[0m".format(method, path)) + try: + response = await call_next(request) + except Exception as exc: + logger.error( + "\033[91m{} {}|ERROR|Exception msg: {}\033[0m".format( + method, path, str(exc) + ) + ) + raise exc + logger.info( + "\033[92m{} {}|RTIME|{:.3f}s\033[0m".format( + method, path, (time.time() - start_time) + ) + ) + return response + - v3_blueprint = create_v3_blueprint() - app.register_blueprint( - v3_blueprint, name="root" - ) # Add v3 to root but will be DEPRECATED - app.register_blueprint(v3_blueprint, url_prefix="/v3") +# NOTE: Middleware is executed in the inverse order of the order they are added. +# So, the HttpUrlModifyMiddleware should be added at the end. 
+app.add_middleware(HttpUrlModifyMiddleware) - v4_blueprint = create_v4_blueprint() - app.register_blueprint(v4_blueprint, url_prefix="/v4") - app.register_error_handler(HTTPException, handle_HTTP_exception) +# Routers - return app +app.include_router(routers.router) diff --git a/autosubmit_api/auth/__init__.py b/autosubmit_api/auth/__init__.py index 92ee07ab2b97200ac1eec43b0aae89fb6f85fe73..c4092e046794480e90aa1bde01c7de8b03aa6c6f 100644 --- a/autosubmit_api/auth/__init__.py +++ b/autosubmit_api/auth/__init__.py @@ -1,6 +1,7 @@ -from functools import wraps from http import HTTPStatus -from flask import request +from typing import Annotated +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer import jwt from autosubmit_api.logger import logger from autosubmit_api import config @@ -26,47 +27,41 @@ def _parse_protection_level_env(_var): return ProtectionLevels.ALL -def with_auth_token( - threshold=ProtectionLevels.ALL, response_on_fail=True, raise_on_fail=False -): +security = HTTPBearer(auto_error=False) + + +def auth_token_dependency(threshold=ProtectionLevels.ALL, raise_on_fail=True): """ - Decorator that validates the Authorization token in a request. + FastAPI Dependency that validates the Authorization token in a request. It adds the `user_id` variable inside the arguments of the wrapped function. 
:param threshold: The minimum PROTECTION_LEVEL that needs to be set to trigger a *_on_fail - :param response_on_fail: if `True` will return a Flask response on fail :param raise_on_fail: if `True` will raise an exception on fail - :raises AppAuthError: if raise_on_fail=True and decoding fails """ - def decorator(func): - @wraps(func) - def inner_wrapper(*args, **kwargs): - try: - current_token = request.headers.get("Authorization") - jwt_token = jwt.decode(current_token, config.JWT_SECRET, config.JWT_ALGORITHM) - except Exception as exc: - error_msg = "Unauthorized" - if isinstance(exc, jwt.ExpiredSignatureError): - error_msg = "Expired token" - auth_level = _parse_protection_level_env(config.PROTECTION_LEVEL) - if threshold <= auth_level: # If True, will trigger *_on_fail - if raise_on_fail: - raise AppAuthError(error_msg) - if response_on_fail: - return { - "error": True, - "error_message": error_msg, - }, HTTPStatus.UNAUTHORIZED - jwt_token = {"user_id": None} - - user_id = jwt_token.get("user_id", None) - logger.debug("decorator user_id: " + str(user_id)) - kwargs["user_id"] = user_id - - return func(*args, **kwargs) - - return inner_wrapper - - return decorator + async def dependency( + credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)], + ): + try: + current_token = credentials.credentials + jwt_token = jwt.decode( + current_token, config.JWT_SECRET, config.JWT_ALGORITHM + ) + except Exception as exc: + error_msg = "Unauthorized" + if isinstance(exc, jwt.ExpiredSignatureError): + error_msg = "Expired token" + auth_level = _parse_protection_level_env(config.PROTECTION_LEVEL) + if threshold <= auth_level: # If True, will trigger *_on_fail + if raise_on_fail: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, detail=error_msg + ) + jwt_token = {"user_id": None} + + user_id = jwt_token.get("user_id", None) + logger.debug("user_id: " + str(user_id)) + return user_id + + return dependency diff --git 
a/autosubmit_api/bgtasks/scheduler.py b/autosubmit_api/bgtasks/scheduler.py index bbca554eed925601027a0028f5f111d5247ec4d3..d7fa3752ff3dab6ab32a856b31b63fd6a2ef15b8 100644 --- a/autosubmit_api/bgtasks/scheduler.py +++ b/autosubmit_api/bgtasks/scheduler.py @@ -1,5 +1,7 @@ from typing import List -from flask_apscheduler import APScheduler +from apscheduler.schedulers.background import BackgroundScheduler +from apscheduler.triggers.interval import IntervalTrigger +from apscheduler.triggers.cron import CronTrigger from autosubmit_api.bgtasks.bgtask import ( BackgroundTaskTemplate, PopulateDetailsDB, @@ -20,18 +22,24 @@ REGISTERED_TASKS: List[BackgroundTaskTemplate] = [ ] -def create_bind_scheduler(app): - scheduler = APScheduler() - scheduler.init_app(app) - scheduler.start() +def create_scheduler(): + scheduler = BackgroundScheduler() if not get_disable_background_tasks(): for task in REGISTERED_TASKS: - scheduler.add_job( - task.id, - with_log_run_times(logger, task.id, catch_exc=True)(task.run), - **task.trigger_options - ) + trigger_options = task.trigger_options.copy() + + job = with_log_run_times(logger, task.id, catch_exc=True)(task.run) + trigger_type = trigger_options.pop("trigger") + trigger = None + if trigger_type == "interval": + trigger = IntervalTrigger(**trigger_options) + elif trigger_type == "cron": + trigger = CronTrigger(**trigger_options) + else: + raise ValueError(f"Invalid trigger type {trigger_type}") + + scheduler.add_job(id=task.id, name=task.id, func=job, trigger=trigger) logger.info( "Background tasks: " + str([str(task) for task in scheduler.get_jobs()]) @@ -40,6 +48,6 @@ def create_bind_scheduler(app): if get_run_background_tasks_on_start(): logger.info("Starting background tasks on app init before serving...") for task in REGISTERED_TASKS: - scheduler.run_job(task.id) + scheduler.get_job(task.id).func() return scheduler diff --git a/autosubmit_api/blueprints/v3.py b/autosubmit_api/blueprints/v3.py deleted file mode 100644 index 
1cfedb2ddec99a11bfe8419a5102f0ef9622f4f5..0000000000000000000000000000000000000000 --- a/autosubmit_api/blueprints/v3.py +++ /dev/null @@ -1,68 +0,0 @@ -from flask import Blueprint -from autosubmit_api.views import v3 as v3_views - - -def create_v3_blueprint(): - blueprint = Blueprint("v3", __name__) - - blueprint.route("/login")(v3_views.login) - blueprint.route("/tokentest", methods=["GET", "POST"])(v3_views.test_token) - blueprint.route("/updatedesc", methods=["GET", "POST"])(v3_views.update_description) - blueprint.route("/cconfig/", methods=["GET"])( - v3_views.get_current_configuration - ) - blueprint.route("/expinfo/", methods=["GET"])(v3_views.exp_info) - blueprint.route("/expcount/", methods=["GET"])(v3_views.exp_counters) - blueprint.route( - "/searchowner///", - methods=["GET"], - )(v3_views.search_owner) - blueprint.route("/searchowner/", methods=["GET"])( - v3_views.search_owner - ) - blueprint.route( - "/search///", methods=["GET"] - )(v3_views.search_expid) - blueprint.route("/search/", methods=["GET"])(v3_views.search_expid) - blueprint.route("/running/", methods=["GET"])(v3_views.search_running) - blueprint.route("/runs/", methods=["GET"])(v3_views.get_runs) - blueprint.route("/ifrun/", methods=["GET"])(v3_views.get_if_running) - blueprint.route("/logrun/", methods=["GET"])( - v3_views.get_running_detail - ) - blueprint.route("/summary/", methods=["GET"])(v3_views.get_expsummary) - blueprint.route("/shutdown/")(v3_views.shutdown) - blueprint.route("/performance/", methods=["GET"])( - v3_views.get_exp_performance - ) - blueprint.route( - "/graph///", methods=["GET"] - )(v3_views.get_graph_format) - blueprint.route("/tree/", methods=["GET"])(v3_views.get_exp_tree) - blueprint.route("/quick/", methods=["GET"])( - v3_views.get_quick_view_data - ) - blueprint.route("/exprun/", methods=["GET"])( - v3_views.get_experiment_run_log - ) - blueprint.route("/joblog/", methods=["GET"])( - v3_views.get_job_log_from_path - ) - blueprint.route("/pklinfo//", 
methods=["GET"])( - v3_views.get_experiment_pklinfo - ) - blueprint.route("/pkltreeinfo//", methods=["GET"])( - v3_views.get_experiment_tree_pklinfo - ) - blueprint.route( - "/stats///" - )(v3_views.get_experiment_statistics) - blueprint.route("/history//")( - v3_views.get_exp_job_history - ) - blueprint.route("/rundetail//")( - v3_views.get_experiment_run_job_detail - ) - blueprint.route("/filestatus/")(v3_views.get_file_status) - - return blueprint diff --git a/autosubmit_api/blueprints/v4.py b/autosubmit_api/blueprints/v4.py deleted file mode 100644 index dd619800c31115bf1c362929f81c0853a8d92ba4..0000000000000000000000000000000000000000 --- a/autosubmit_api/blueprints/v4.py +++ /dev/null @@ -1,98 +0,0 @@ -from flask import Blueprint -from autosubmit_api.views import v4 as v4_views - - -def create_v4_blueprint(): - blueprint = Blueprint("v4", __name__) - - # TODO Uncomment endpoints as they are ready to be published - - blueprint.add_url_rule( - "/auth/cas/v2/login", view_func=v4_views.CASV2Login.as_view("CASV2Login") - ) - blueprint.add_url_rule( - "/auth/oauth2/github/login", - view_func=v4_views.GithubOauth2Login.as_view("GithubOauth2Login"), - ) - blueprint.add_url_rule( - "/auth/verify-token", view_func=v4_views.AuthJWTVerify.as_view("AuthJWTVerify") - ) - - # blueprint.route("/experiments//description", methods=["PUT"])( - # v4_views.experiment_description_view - # ) - # blueprint.route("/experiments//info")(v3_views.exp_info) - # blueprint.route("/experiments//status-counters")( - # v3_views.exp_counters - # ) - - blueprint.add_url_rule( - "/experiments", view_func=v4_views.ExperimentView.as_view("ExperimentView") - ) - - blueprint.add_url_rule( - "/experiments/", - view_func=v4_views.ExperimentDetailView.as_view("ExperimentDetailView"), - ) - - blueprint.add_url_rule( - "/experiments//jobs", - view_func=v4_views.ExperimentJobsView.as_view("ExperimentJobsView"), - ) - - blueprint.add_url_rule( - "/experiments//wrappers", - 
view_func=v4_views.ExperimentWrappersView.as_view("ExperimentWrappersView"), - ) - blueprint.add_url_rule( - "/experiments//filesystem-config", - view_func=v4_views.ExperimentFSConfigView.as_view("ExperimentFSConfigView"), - ) - blueprint.add_url_rule( - "/experiments//runs", - view_func=v4_views.ExperimentRunsView.as_view("ExperimentRunsView"), - ) - blueprint.add_url_rule( - "experiments//runs//config", - view_func=v4_views.ExperimentRunConfigView.as_view("ExperimentRunConfigView"), - ) - - # blueprint.route("/experiments//runs")(v3_views.get_runs) - # blueprint.route("/experiments//check-running")( - # v3_views.get_if_running - # ) - # blueprint.route("/experiments//running-detail")( - # v3_views.get_running_detail - # ) - # blueprint.route("/experiments//summary")(v3_views.get_expsummary) - - # blueprint.route("/routes//shutdown")(v3_views.shutdown) - - # blueprint.route("/experiments//performance")( - # v3_views.get_exp_performance - # ) - # blueprint.route("/experiments//graph")(v4_views.exp_graph_view) - # blueprint.route("/experiments//tree")(v3_views.get_exp_tree) - # blueprint.route("/experiments//quick")(v3_views.get_quick_view_data) - - # blueprint.route("/experiments//run-log")( - # v3_views.get_experiment_run_log - # ) - # blueprint.route("/job-logs/")(v3_views.get_job_log_from_path) - - # blueprint.route("/experiments//graph-diff")( - # v3_views.get_experiment_pklinfo - # ) - # blueprint.route("/experiments//tree-diff")( - # v3_views.get_experiment_tree_pklinfo - # ) - # blueprint.route("/experiments//stats")(v4_views.exp_stats_view) - # blueprint.route("/experiments//jobs//history")( - # v3_views.get_exp_job_history - # ) - # blueprint.route("/experiments//runs/")( - # v3_views.get_experiment_run_job_detail - # ) - # blueprint.route("/filestatus")(v3_views.get_file_status) - - return blueprint diff --git a/autosubmit_api/cli.py b/autosubmit_api/cli.py index 4b3d1453b46f40afed8a17fbcf6201775b850c2d..48d8246c294acaac667a7a4f4174b6040ba21343 100644 --- 
a/autosubmit_api/cli.py +++ b/autosubmit_api/cli.py @@ -51,6 +51,7 @@ def start_app_gunicorn( "preload_app": True, "capture_output": True, "timeout": 600, + "worker_class": "uvicorn.workers.UvicornWorker" } if bind and len(bind) > 0: options["bind"] = bind @@ -77,7 +78,7 @@ def start_app_gunicorn( if keepalive and keepalive > 0: options["keepalive"] = keepalive - g_app = StandaloneApplication("autosubmit_api.app:create_app()", options) + g_app = StandaloneApplication("autosubmit_api.app:app", options) print("Starting with gunicorn options: " + str(g_app.options)) g_app.run() diff --git a/autosubmit_api/config/__init__.py b/autosubmit_api/config/__init__.py index e09cf1bd31477c07e90d932725610b7841a49f38..1f8010417b723dcff3bcdfc389100642f4062232 100644 --- a/autosubmit_api/config/__init__.py +++ b/autosubmit_api/config/__init__.py @@ -46,3 +46,5 @@ def get_disable_background_tasks(): "T", "true", ] # Default false + +AS_API_ROOT_PATH=os.environ.get("AS_API_ROOT_PATH", "") \ No newline at end of file diff --git a/autosubmit_api/logger.py b/autosubmit_api/logger.py index b5c4a8f1ad29cf21769cfcef44190a0d58bf979b..1ae2eb0a512ca5b1b3778db72fded43dbe4bae53 100644 --- a/autosubmit_api/logger.py +++ b/autosubmit_api/logger.py @@ -3,14 +3,12 @@ import logging import time import traceback -from flask import request - def with_log_run_times( _logger: logging.Logger, _tag: str = "", catch_exc: bool = False ): """ - Function decorator to log runtimes of the endpoints + Function decorator to log runtimes :param _logger: logger to use :param _tag: tag for the logs :param catch_exc: if True, will catch any Exception and not raise it @@ -20,27 +18,17 @@ def with_log_run_times( @wraps(func) def inner_wrapper(*args, **kwargs): start_time = time.time() - try: - path = request.full_path - method = request.method - except Exception: - path = "" - method = "" - _logger.info("{}|RECEIVED|{}{}".format(_tag, method, path)) + _logger.info("\033[94m{}|RECEIVED\033[0m".format(_tag)) response = 
None try: response = func(*args, **kwargs) except Exception as exc: - _logger.error("{}|ERROR|{}|Exception msg: {}".format(_tag, path, exc)) + _logger.error("\033[91m{}|ERROR|Exception msg: {}\033[0m".format(_tag, exc)) if catch_exc: _logger.error(traceback.format_exc()) else: raise exc - _logger.info( - "{}|RTIME|{}{}|{:.3f}".format( - _tag, method, path, (time.time() - start_time) - ) - ) + _logger.info("\033[92m{}|RTIME|{:.3f}\033[0m".format(_tag, (time.time() - start_time))) return response return inner_wrapper diff --git a/autosubmit_api/middleware.py b/autosubmit_api/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae5ce041b39edac5202d20aeb1abd51681d24ff --- /dev/null +++ b/autosubmit_api/middleware.py @@ -0,0 +1,20 @@ +import re + +from starlette.types import ASGIApp, Receive, Scope, Send + +repeated_quotes = re.compile(r"//+") + + +class HttpUrlModifyMiddleware: + """ + This http middleware modifies urls with repeated slashes to the cleaned up + versions of the urls without redirecting + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and repeated_quotes.search(scope["path"]): + scope["path"] = repeated_quotes.sub("/", scope["path"]) + await self.app(scope, receive, send) diff --git a/autosubmit_api/blueprints/__init__.py b/autosubmit_api/models/__init__.py similarity index 100% rename from autosubmit_api/blueprints/__init__.py rename to autosubmit_api/models/__init__.py diff --git a/autosubmit_api/models/experiment.py b/autosubmit_api/models/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..d869311c21505e1dca74412e5aa911792bc9a2a8 --- /dev/null +++ b/autosubmit_api/models/experiment.py @@ -0,0 +1,83 @@ +from typing import Annotated, List, Optional +from pydantic import BaseModel, Field + + +class BaseExperimentRun(BaseModel): + run_id: Annotated[int, 
Field(description="Run ID", example=1)] + start: Annotated[ + Optional[str], + Field(description="Start time of the run", example="2021-01-01T00:00:00Z"), + ] + finish: Annotated[ + Optional[str], + Field(description="Finish time of the run", example="2021-01-01T00:00:00Z"), + ] + + +class ExperimentSearchItem(BaseModel): + id: Annotated[int, Field(description="Numerical Experiment ID", example=3)] + name: Annotated[str, Field(description="Expid of the experiment", example="a002")] + user: Annotated[ + Optional[str], + Field(description="User who created the experiment", example="user"), + ] + description: Annotated[ + Optional[str], + Field(description="Description of the experiment", example="Description"), + ] + hpc: Annotated[ + Optional[str], + Field(description="HPC where the experiment was run", example="hpc"), + ] + version: Annotated[ + Optional[str], + Field(description="Version of Autosubmit of this experiment", example="4.1.0"), + ] + created: Annotated[ + Optional[str], + Field( + description="Creation time of the experiment", example="2023-10-16 15:21:51" + ), + ] + modified: Annotated[ + Optional[str], + Field( + description="Last modification time of the experiment", + example="2024-01-12T16:38:56", + ), + ] + status: Annotated[ + Optional[str], Field(description="Status of the experiment", example="RUNNING") + ] + completed: Annotated[ + Optional[int], Field(description="Number of completed jobs", example=0) + ] + total: Annotated[ + Optional[int], Field(description="Total number of jobs", example=8) + ] + submitted: Annotated[ + Optional[int], Field(description="Number of submitted jobs", example=0) + ] + queuing: Annotated[ + Optional[int], Field(description="Number of jobs in the queue", example=0) + ] + running: Annotated[ + Optional[int], Field(description="Number of jobs running", example=0) + ] + failed: Annotated[ + Optional[int], Field(description="Number of failed jobs", example=0) + ] + suspended: Annotated[ + Optional[int], 
Field(description="Number of suspended jobs", example=0) + ] + + +class BaseExperimentWrapper(BaseModel): + wrapper_name: Annotated[ + str, + Field( + description="Name of the wrapper", + example="a6zi_ASThread_17108816522924_528_10", + ), + ] + job_names: Annotated[List[str], Field(description="List of job names")] diff --git a/autosubmit_api/models/misc.py b/autosubmit_api/models/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e706dd5c5daa0d977bcf22ee004b07cabcc6ff31 --- /dev/null +++ b/autosubmit_api/models/misc.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel, Field +from typing import Annotated, List, Optional + + +class RouteInfo(BaseModel): + path: Annotated[str, Field(description="Path of the route", example="/")] + methods: List[Annotated[str, Field(description="HTTP methods", example="GET")]] + + +class PaginationInfo(BaseModel): + page: Annotated[int, Field(1, ge=1, description="Page number", example=1)] + page_size: Annotated[ + Optional[int], Field(description="Number of items per page", example=12) + ] + total_pages: Annotated[ + int, Field(1, ge=1, description="Total number of pages", example=1) + ] + page_items: Annotated[ + int, Field(ge=0, description="Number of items in the current page", example=1) + ] + total_items: Annotated[ + int, Field(ge=0, description="Total number of items", example=1) + ] diff --git a/autosubmit_api/models/requests.py b/autosubmit_api/models/requests.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b53cc50787515484929361e771e8b8803983a4 --- /dev/null +++ b/autosubmit_api/models/requests.py @@ -0,0 +1,18 @@ +from typing import Annotated, Literal, Optional +from pydantic import BaseModel, Field + +PAGINATION_LIMIT_DEFAULT = 12 + + +class ExperimentsSearchRequest(BaseModel): + query: Optional[str] = None + only_active: bool = True + owner: Optional[str] = None + exp_type: Optional[Literal["test", "operational", "experiment"]] = None + autosubmit_version: 
Optional[str] = None + + order_by: Optional[Literal["expid", "created", "description"]] = None + order_desc: bool = True + + page: Annotated[int, Field(ge=1, description="Page number", example=1)] = 1 + page_size: int = PAGINATION_LIMIT_DEFAULT diff --git a/autosubmit_api/models/responses.py b/autosubmit_api/models/responses.py new file mode 100644 index 0000000000000000000000000000000000000000..85093b01c1f1737ed31c216d01ee57eb870c3260 --- /dev/null +++ b/autosubmit_api/models/responses.py @@ -0,0 +1,50 @@ +from typing import List, Optional +from pydantic import BaseModel + +from autosubmit_api.database.models import PklJobModel +from autosubmit_api.models.experiment import ( + BaseExperimentRun, + BaseExperimentWrapper, + ExperimentSearchItem, +) +from autosubmit_api.models.misc import PaginationInfo, RouteInfo + + +class AuthResponse(BaseModel): + authenticated: bool + user: Optional[str] + + +class LoginResponse(AuthResponse): + token: Optional[str] + message: Optional[str] + + +class ExperimentRunsResponse(BaseModel): + runs: List[BaseExperimentRun] + + +class RoutesResponse(BaseModel): + routes: List[RouteInfo] + + +class ExperimentsSearchResponse(BaseModel): + experiments: List[ExperimentSearchItem] + pagination: PaginationInfo + + +class ExperimentJobsResponse(BaseModel): + jobs: List[PklJobModel] + + +class ExperimentFSConfigResponse(BaseModel): + config: dict + + +class ExperimentRunConfigResponse(BaseModel): + run_id: Optional[int] + config: dict + + +class ExperimentWrappersResponse(BaseModel): + wrappers: List[BaseExperimentWrapper] diff --git a/autosubmit_api/routers/__init__.py b/autosubmit_api/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbef66551f9637335d0910bab796c600d2649d41 --- /dev/null +++ b/autosubmit_api/routers/__init__.py @@ -0,0 +1,14 @@ +from fastapi import APIRouter +from autosubmit_api import __version__ as APIVersion +from autosubmit_api.routers import v4, v3 + +router = APIRouter() + + 
+@router.get("/", name="Home") +async def home(): + return {"name": "Autosubmit API", "version": APIVersion} + + +router.include_router(v4.router, prefix="/v4", tags=["v4"]) +router.include_router(v3.router, prefix="/v3", tags=["v3"]) \ No newline at end of file diff --git a/autosubmit_api/routers/v3/__init__.py b/autosubmit_api/routers/v3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b4af29038bd73995088f47b41a2f0037a1fb0f --- /dev/null +++ b/autosubmit_api/routers/v3/__init__.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +from autosubmit_api.routers.v3 import experiments, auth + +router = APIRouter() + +router.include_router(experiments.router) +router.include_router(auth.router) diff --git a/autosubmit_api/routers/v3/auth.py b/autosubmit_api/routers/v3/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..f6fb0decb03400392a4366b5575a2f484e53d52b --- /dev/null +++ b/autosubmit_api/routers/v3/auth.py @@ -0,0 +1,108 @@ +from datetime import datetime, timedelta, timezone +from http import HTTPStatus +from typing import Optional +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, RedirectResponse +import jwt +import requests +from autosubmit_api import config +from autosubmit_api.auth import ProtectionLevels, auth_token_dependency +from autosubmit_api.config.basicConfig import APIBasicConfig +from autosubmit_api.models.responses import LoginResponse +from autosubmit_api.logger import logger +from autosubmit_api.experiment import utils as Utiles + + +router = APIRouter() + + +@router.get("/tokentest", name="Test JWT token") +async def test_token( + user_id: Optional[str] = Depends( + auth_token_dependency(threshold=ProtectionLevels.NONE, raise_on_fail=False) + ), +) -> dict: + """ + Tests if a token is still valid + """ + return JSONResponse( + status_code=(HTTPStatus.OK if user_id else HTTPStatus.UNAUTHORIZED), + content={ + "isValid": True if user_id else 
False, + "message": "Unauthorized" if not user_id else None, + }, + ) + + +@router.get("/login", name="CAS Login") +async def login( + ticket: Optional[str] = None, + env: Optional[str] = None, + request: Request = None, +) -> LoginResponse: + APIBasicConfig.read() + ticket = ticket + environment = env + referrer = request.headers.get("Referer") + is_allowed = False + for allowed_client in APIBasicConfig.ALLOWED_CLIENTS: + if referrer and referrer.find(allowed_client) >= 0: + referrer = allowed_client + is_allowed = True + if is_allowed is False: + return JSONResponse( + { + "authenticated": False, + "user": None, + "token": None, + "message": "Your client is not authorized for this operation. The API admin needs to add your URL to the list of allowed clients.", + }, + status_code=HTTPStatus.UNAUTHORIZED, + ) + + target_service = "{}{}/login".format(referrer, environment) + if not ticket: + route_to_request_ticket = "{}?service={}".format( + config.CAS_LOGIN_URL, target_service + ) + logger.info("Redirected to: " + str(route_to_request_ticket)) + return RedirectResponse(url=route_to_request_ticket) + # can be used to target the test environment + # environment = environment if environment is not None else "autosubmitapp" + cas_verify_ticket_route = ( + config.CAS_VERIFY_URL + "?service=" + target_service + "&ticket=" + ticket + ) + response = requests.get(cas_verify_ticket_route) + user = None + if response: + user = Utiles.get_cas_user_from_xml(response.content) + logger.info("CAS verify ticket response: user %s", user) + if not user: + return JSONResponse( + { + "authenticated": False, + "user": None, + "token": None, + "message": "Can't verify user.", + }, + status_code=HTTPStatus.UNAUTHORIZED, + ) + else: # Login successful + payload = { + "user_id": user, + "sub": user, + "iat": int(datetime.now().timestamp()), + "exp": ( + datetime.now(timezone.utc) + + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) + ), + } + jwt_token = jwt.encode(payload, 
config.JWT_SECRET, config.JWT_ALGORITHM) + return JSONResponse( + { + "authenticated": True, + "user": user, + "token": f"Bearer {jwt_token}", + "message": "Token generated.", + } + ) diff --git a/autosubmit_api/routers/v3/experiments.py b/autosubmit_api/routers/v3/experiments.py new file mode 100644 index 0000000000000000000000000000000000000000..b1439eb1959408dfe3600b6d5479262e8f7b49bd --- /dev/null +++ b/autosubmit_api/routers/v3/experiments.py @@ -0,0 +1,438 @@ +from multiprocessing import Lock, Manager +import os +from typing import Literal, Optional +from fastapi import APIRouter, Depends, Path, Query +from autosubmit_api.builders.joblist_helper_builder import ( + JobListHelperBuilder, + JobListHelperDirector, +) +from autosubmit_api.logger import logger +from autosubmit_api.auth import auth_token_dependency +from autosubmit_api.database.db_common import ( + get_current_running_exp, + search_experiment_by_id, +) +from autosubmit_api.experiment import common_requests as CommonRequests +from autosubmit_api.performance.performance_metrics import PerformanceMetrics + + +router = APIRouter() + +# Multiprocessing setup +D = Manager().dict() +lock = Lock() + + +@router.get("/cconfig/{expid}", name="Get Experiment Configuration") +async def get_current_configuration( + expid: str, user_id: Optional[str] = Depends(auth_token_dependency()) +) -> dict: + """ + Get the current configuration of the experiment + """ + result = CommonRequests.get_current_configuration_by_expid(expid, user_id) + return result + + +@router.get("/expinfo/{expid}", name="Get Experiment Information") +async def get_exp_info( + expid: str, user_id: Optional[str] = Depends(auth_token_dependency()) +) -> dict: + """ + Get the information of the experiment + """ + result = CommonRequests.get_experiment_data(expid) + return result + + +@router.get("/expcount/{expid}", name="Get Experiment Counters") +async def get_exp_counters( + expid: str, user_id: Optional[str] = Depends(auth_token_dependency()) 
@router.get("/expcount/{expid}", name="Get Experiment Counters")
async def get_exp_counters(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Get the job status counters of the experiment."""
    return CommonRequests.get_experiment_counters(expid)


@router.get("/searchowner/{owner}", name="Search Owner")
async def search_owner(
    owner: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Search for experiments by owner."""
    return search_experiment_by_id(
        query=None, owner=owner, exp_type=None, only_active=None
    )


@router.get("/search/{expid}", name="Search Experiment by expid")
async def search_expid(
    expid: str,
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Search for experiments by expid."""
    return search_experiment_by_id(
        query=expid, owner=None, exp_type=None, only_active=None
    )


@router.get(
    "/search/{expid}/{exptype}/{onlyactive}",
    name="Search Experiment by expid, type and active status",
)
async def search_expid_plus(
    expid: str,
    exptype: str,
    onlyactive: str,
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Search for experiments by expid, type and active status."""
    return search_experiment_by_id(
        query=expid, owner=None, exp_type=exptype, only_active=onlyactive
    )


@router.get("/running/", name="Search Running Experiments")
async def search_running(
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Returns the list of all experiments that are currently running."""
    return get_current_running_exp()


@router.get("/runs/{expid}", name="Get Experiment Runs")
async def get_runs(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Get list of runs of the same experiment from the historical db."""
    return CommonRequests.get_experiment_runs(expid)


@router.get("/ifrun/{expid}", name="Check if Experiment is Running")
async def get_if_running(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Quick check if the experiment is currently running."""
    return CommonRequests.quick_test_run(expid)


@router.get("/logrun/{expid}", name="Get Experiment running status and log path")
async def get_running_detail(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Get Experiment running status and log path."""
    return CommonRequests.get_current_status_log_plus(expid)


@router.get("/summary/{expid}", name="Get Experiment Summary")
async def get_exp_summary(
    expid: str,
    loggedUser: str = Query("null"),
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """
    Get Experiment Summary.

    While the summary is being computed this worker registers itself in the
    shared ``D`` dict so the ``/shutdown`` endpoint can target it.
    """
    track = loggedUser != "null"
    if track:
        # FIX: "with lock" releases the lock even if the proxied dict raises;
        # the original bare acquire()/release() pair could leak the lock.
        with lock:
            D[os.getpid()] = [loggedUser, "summary", True]
    result = CommonRequests.get_experiment_summary(expid, logger)
    logger.info("Process: " + str(os.getpid()) + " workers: " + str(D))
    if track:
        # FIX: the original wrote a "finished" entry and immediately popped
        # it; a single pop is equivalent and removes the dead write.
        with lock:
            D.pop(os.getpid(), None)
    return result


@router.get("/shutdown/{route}", name="Shutdown Experiment processes")
async def shutdown(
    route: str,
    loggedUser: str = Query("null"),
    expid: str = Query("null"),
    user_id: Optional[str] = Depends(auth_token_dependency()),
):
    """
    Invoked from the frontend (AS-GUI) to kill workers that are no longer
    needed. Common for heavy endpoints such as Tree/Graph generation or
    Summaries fetching.
    """
    import signal  # local import: only this admin endpoint needs it

    if loggedUser != "null":
        logger.info(
            "SHUTDOWN|DETAILS|route: %s user: %s expid: %s", route, loggedUser, expid
        )
        try:
            logger.info("Workers before: " + str(D))
            # FIX: "with lock" guarantees release; the original could exit the
            # try block with the lock still held if the loop raised.
            with lock:
                for worker_pid, info in list(D.items()):
                    # info layout: [user, route, (expid,) busy_flag]; the
                    # original rewrote the entry before popping it, which was
                    # a dead write - popping alone is equivalent.
                    if (
                        info[0] == loggedUser
                        and info[1] == route
                        and info[-1] is True
                    ):
                        D.pop(worker_pid, None)
                        # Reboot the worker. FIX: os.kill + SIGHUP instead of
                        # building a shell command with os.system().
                        os.kill(worker_pid, signal.SIGHUP)
                        logger.info("killed worker " + str(worker_pid))
            logger.info("Workers now: " + str(D))
        except Exception:
            logger.error(
                '[CRITICAL] Could not shutdown process %s by user "%s"',
                expid,
                loggedUser,
            )
    return ""


@router.get("/performance/{expid}", name="Get Experiment Performance")
async def get_exp_performance(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """
    Get Experiment Performance metrics (SYPD, ASYPD, RSYPD, CHSY, JPSY, ...).

    On any failure a placeholder payload with ``error=True`` and the error
    message is returned instead of raising, so the GUI can render the error.
    """
    try:
        return PerformanceMetrics(
            expid,
            JobListHelperDirector(JobListHelperBuilder(expid)).build_job_list_helper(),
        ).to_json()
    except Exception as exc:
        return {
            "SYPD": None,
            "ASYPD": None,
            "RSYPD": None,
            "CHSY": None,
            "JPSY": None,
            "Parallelization": None,
            "PE": None,
            "considered": [],
            "error": True,
            "error_message": str(exc),
            "warnings_job_data": [],
        }
@router.get("/graph/{expid}/{layout}/{grouped}", name="Get Experiment Graph")
async def get_graph_format(
    expid: str,
    layout: Literal["standard", "laplacian"],
    grouped: Literal["none", "status", "date-member"],
    loggedUser: str = Query("null"),
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """
    Get Experiment Graph.

    Registers this worker in the shared ``D`` dict while the graph is being
    generated so the ``/shutdown`` endpoint can target it.
    """
    track = loggedUser != "null"
    if track:
        # FIX: "with lock" releases the lock even on error; the original
        # bare acquire()/release() pair could leak it.
        with lock:
            D[os.getpid()] = [loggedUser, "graph", expid, True]
    result = CommonRequests.get_experiment_graph(expid, logger, layout, grouped)
    logger.info("Process: " + str(os.getpid()) + " graph workers: " + str(D))
    if track:
        # FIX: the original wrote a "finished" entry and immediately popped
        # it; a single pop is equivalent.
        with lock:
            D.pop(os.getpid(), None)
    return result


@router.get("/tree/{expid}", name="Get Experiment Tree")
async def get_exp_tree(
    expid: str,
    loggedUser: str = Query("null"),
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """
    Get Experiment Tree.

    Registers this worker in the shared ``D`` dict while the tree is being
    generated so the ``/shutdown`` endpoint can target it.
    """
    track = loggedUser != "null"
    if track:
        with lock:  # released even on error (see graph handler)
            D[os.getpid()] = [loggedUser, "tree", expid, True]
    result = CommonRequests.get_experiment_tree_structured(expid, logger)
    logger.info("Process: " + str(os.getpid()) + " tree workers: " + str(D))
    if track:
        with lock:
            D.pop(os.getpid(), None)
    return result


@router.get("/quick/{expid}", name="Get Quick View Data")
async def get_quick_view_data(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Get Quick View Data."""
    return CommonRequests.get_quick_view(expid)


@router.get("/exprun/{expid}", name="Get Experiment Run Log")
async def get_experiment_run_log(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """Finds the experiment log and returns its last 150 lines."""
    return CommonRequests.get_experiment_log_last_lines(expid)


@router.get("/joblog/{logfile}", name="Get Job Log from Path")
async def get_job_log_from_path(
    logfile: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> dict:
    """
    Get Job Log from Path.

    The expid is derived from the log file name: everything before the first
    underscore (``str.split`` always yields at least one element, so no
    separate length check is needed).
    """
    expid = logfile.split("_")[0] if logfile is not None else ""
    return CommonRequests.get_job_log(expid, logfile)


@router.get(
    "/pklinfo/{expid}/{timeStamp}", name="Get Experiment Pickle Info for Graph View"
)
async def get_experiment_pklinfo(
    expid: str,
    timeStamp: str = Path(description="Unused path parameter", example="0"),
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Get Experiment Pickle Info for Graph View (timeStamp is legacy/unused)."""
    return CommonRequests.get_experiment_pkl(expid)


@router.get(
    "/pkltreeinfo/{expid}/{timeStamp}", name="Get Experiment Pickle Info for Tree View"
)
async def get_experiment_tree_pklinfo(
    expid: str,
    timeStamp: str = Path(description="Unused path parameter", example="0"),
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Get Experiment Pickle Info for Tree View (timeStamp is legacy/unused)."""
    return CommonRequests.get_experiment_tree_pkl(expid)


@router.get(
    "/stats/{expid}/{filter_period}/{filter_type}", name="Get Experiment Statistics"
)
async def get_experiment_statistics(
    expid: str,
    filter_period: str,
    filter_type: str = Path(
        example="Any", description="Job Section filter, use 'Any' for all"
    ),
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Get Experiment Statistics."""
    return CommonRequests.get_experiment_stats(expid, filter_period, filter_type)


@router.get("/history/{expid}/{jobname}", name="Get Experiment Job History")
async def get_exp_job_history(
    expid: str,
    jobname: str,
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Get Experiment Job History."""
    return CommonRequests.get_job_history(expid, jobname)


@router.get("/rundetail/{expid}/{runid}", name="Get Experiment Run Job Detail")
async def get_experiment_run_job_detail(
    expid: str,
    runid: str,
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> dict:
    """Get Experiment Run Job Detail."""
    return CommonRequests.get_experiment_tree_rundetail(expid, runid)


@router.get("/filestatus", name="[UNSUPPORTED] Get File Status")
async def get_file_status() -> dict:
    """
    This endpoint is not supported in this version of the API.
    Will be removed in future versions.
    """
    return {
        "status": False,
        "error": False,
        "error_message": (
            "This endpoint is not supported in this version of the API."
            "Will be removed in future versions."
        ),
        "avg_latency": None,
        "avg_bandwidth": None,
        "current_latency": None,
        "current_bandwidth": None,
        # NOTE: "reponse_time" (sic) is kept misspelled on purpose - it is
        # the key historically consumed by clients of this legacy endpoint.
        "reponse_time": None,
        "datetime": None,
        "latency_warning": None,
        "bandwidth_warning": None,
        "response_warning": None,
    }
def _issue_bearer_token(username: str) -> str:
    """
    Build a signed JWT for *username* (shared by the CAS and GitHub flows).

    FIX: both claims are computed from the same timezone-aware UTC clock;
    the original built ``exp`` from a naive ``datetime.now()``, inconsistent
    with the v3 login and wrong on non-UTC hosts.
    """
    from datetime import timezone  # local: the module header imports only datetime/timedelta

    now = datetime.now(timezone.utc)
    payload = {
        "user_id": username,
        "sub": username,
        "iat": int(now.timestamp()),
        "exp": now + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS),
    }
    return jwt.encode(payload, config.JWT_SECRET, config.JWT_ALGORITHM)


@router.get("/verify-token", name="Verify JWT token")
async def verify_token(
    user_id: Optional[str] = Depends(
        auth_token_dependency(threshold=ProtectionLevels.NONE, raise_on_fail=False)
    ),
) -> AuthResponse:
    """
    Verify JWT endpoint: reports whether the supplied token resolves to a user.
    """
    return JSONResponse(
        status_code=(HTTPStatus.OK if user_id else HTTPStatus.UNAUTHORIZED),
        content={
            "authenticated": True if user_id else False,
            "user": user_id,
        },
    )


@router.get("/cas/v2/login", name="CAS v2 login")
async def cas_v2_login(
    request: Request, service: Optional[str] = None, ticket: Optional[str] = None
) -> LoginResponse:
    """
    CAS v2 login endpoint.

    Without a ``ticket`` the caller is redirected to the CAS login page; with
    one, the ticket is verified and a JWT bearer token is returned.
    """
    if not service:
        service = request.base_url

    # The API's own base URL is always allowed; external services must be
    # whitelisted.
    is_allowed_service = (service == request.base_url) or validate_client(service)

    if not is_allowed_service:
        return JSONResponse(
            content={
                "authenticated": False,
                "user": None,
                "token": None,
                "message": "Your service is not authorized for this operation. The API admin needs to add your URL to the list of allowed clients.",
            },
            status_code=HTTPStatus.UNAUTHORIZED,
        )

    cas_client = CASClient(
        version=2, service_url=service, server_url=config.CAS_SERVER_URL
    )

    if not ticket:
        # No ticket: the request comes from the end user, send to CAS login.
        return RedirectResponse(url=cas_client.get_login_url())

    # There is a ticket: the request comes from CAS as a callback; verify it
    # to validate the ticket and get the user profile.
    user, attributes, pgtiou = cas_client.verify_ticket(ticket)

    if not user:
        return JSONResponse(
            content={
                "authenticated": False,
                "user": None,
                "token": None,
                "message": "Can't verify user",
            },
            status_code=HTTPStatus.UNAUTHORIZED,
        )

    # Login successful.
    jwt_token = _issue_bearer_token(user)
    return JSONResponse(
        content={
            "authenticated": True,
            "user": user,
            "token": f"Bearer {jwt_token}",
            "message": "Token generated",
        },
        status_code=HTTPStatus.OK,
    )


@router.get("/oauth2/github/login", name="Github OAuth2 login")
async def github_oauth2_login(code: Optional[str] = None) -> LoginResponse:
    """
    Authenticate and authorize a user using a configured GitHub OAuth app.
    The authorization is done by verifying the user's membership to either a
    GitHub Team or Organization.
    """
    if not code:
        return JSONResponse(
            content={
                "authenticated": False,
                "user": None,
                "token": None,
                "message": "Can't verify user",
            },
            status_code=HTTPStatus.UNAUTHORIZED,
        )

    # Exchange the OAuth code for an access token.
    # FIX: timeouts added so a slow GitHub API cannot hang the worker.
    resp_obj: dict = requests.post(
        "https://github.com/login/oauth/access_token",
        data={
            "client_id": config.GITHUB_OAUTH_CLIENT_ID,
            "client_secret": config.GITHUB_OAUTH_CLIENT_SECRET,
            "code": code,
        },
        headers={"Accept": "application/json"},
        timeout=30,
    ).json()
    access_token = resp_obj.get("access_token")

    user_info: dict = requests.get(
        "https://api.github.com/user",
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=30,
    ).json()
    username = user_info.get("login")

    if not username:
        return JSONResponse(
            content={
                "authenticated": False,
                "user": None,
                "token": None,
                "message": "Couldn't find user on GitHub",
            },
            status_code=HTTPStatus.UNAUTHORIZED,
        )

    # Membership check, most restrictive configuration first.
    if (
        config.GITHUB_OAUTH_WHITELIST_ORGANIZATION
        and config.GITHUB_OAUTH_WHITELIST_TEAM
    ):
        # Whitelist a specific organization team.
        # https://docs.github.com/en/rest/teams/members?apiVersion=2022-11-28#get-team-membership-for-a-user
        org_resp = requests.get(
            f"https://api.github.com/orgs/{config.GITHUB_OAUTH_WHITELIST_ORGANIZATION}/teams/{config.GITHUB_OAUTH_WHITELIST_TEAM}/memberships/{username}",
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=30,
        )
        membership: dict = org_resp.json()
        is_member = org_resp.status_code == 200 and membership.get("state") == "active"
    elif config.GITHUB_OAUTH_WHITELIST_ORGANIZATION:
        # Whitelist the whole organization (no team).
        # https://docs.github.com/en/rest/orgs/members?apiVersion=2022-11-28#check-organization-membership-for-a-user
        org_resp = requests.get(
            f"https://api.github.com/orgs/{config.GITHUB_OAUTH_WHITELIST_ORGANIZATION}/members/{username}",
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=30,
        )
        is_member = org_resp.status_code == 204
    else:
        # No authorization check configured.
        is_member = True

    if is_member:
        # Login successful.
        jwt_token = _issue_bearer_token(username)
        return JSONResponse(
            content={
                "authenticated": True,
                "user": username,
                "token": f"Bearer {jwt_token}",
                "message": "Token generated",
            },
            status_code=HTTPStatus.OK,
        )
    else:
        return JSONResponse(
            content={
                "authenticated": False,
                "user": None,
                "token": None,
                "message": "User is not member of organization or team",
            },
            status_code=HTTPStatus.UNAUTHORIZED,
        )
@router.get("", name="Search experiments")
async def search_experiments(
    query_params: Annotated[ExperimentsSearchRequest, Query()],
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> ExperimentsSearchResponse:
    """
    Search experiments.

    A non-positive ``page_size`` disables pagination: every matching row is
    returned and the reported page size is null.
    """
    logger.debug(f"Search args: {query_params}")

    # Translate page/page_size into a LIMIT/OFFSET pair.
    if query_params.page_size > 0:
        offset = (query_params.page - 1) * query_params.page_size
    else:
        offset = None
        query_params.page_size = None

    # Build and execute the search statement.
    statement = generate_query_listexp_extended(
        query=query_params.query,
        only_active=query_params.only_active,
        owner=query_params.owner,
        exp_type=query_params.exp_type,
        autosubmit_version=query_params.autosubmit_version,
        order_by=query_params.order_by,
        order_desc=query_params.order_desc,
    )
    with create_main_db_conn() as conn:
        query_result, total_rows = execute_with_limit_offset(
            statement=statement,
            conn=conn,
            limit=query_params.page_size,
            offset=offset,
        )

    async def _row_to_payload(row):
        # Base experiment record plus pkl modification time.
        builder = ExperimentBuilder()
        builder.produce_base_from_dict(row._mapping)
        builder.produce_pkl_modified_time()
        exp = builder.product

        # Job-status counters from the latest run in the historical DB;
        # fall back to zeroes when the history cannot be read.
        completed = total = submitted = queuing = running = failed = suspended = 0
        try:
            current_run = (
                ExperimentHistoryDirector(ExperimentHistoryBuilder(exp.name))
                .build_reader_experiment_history()
                .manager.get_experiment_run_dc_with_max_id()
            )
            if current_run and current_run.total > 0:
                completed = current_run.completed
                total = current_run.total
                submitted = current_run.submitted
                queuing = current_run.queuing
                running = current_run.running
                failed = current_run.failed
                suspended = current_run.suspended
        except Exception as exc:
            logger.warning(f"Exception getting the current run on search: {exc}")
            logger.warning(traceback.format_exc())

        return {
            "id": exp.id,
            "name": exp.name,
            "user": exp.user,
            "description": exp.description,
            "hpc": exp.hpc,
            "version": exp.autosubmit_version,
            "created": exp.created,
            "modified": exp.modified,
            "status": exp.status if exp.status else "NOT RUNNING",
            "completed": completed,
            "total": total,
            "submitted": submitted,
            "queuing": queuing,
            "running": running,
            "failed": failed,
            "suspended": suspended,
        }

    experiments = await asyncio.gather(
        *(_row_to_payload(row) for row in query_result)
    )

    if query_params.page_size:
        total_pages = math.ceil(total_rows / query_params.page_size)
    else:
        total_pages = 1

    response = {
        "experiments": experiments,
        "pagination": {
            "page": query_params.page,
            "page_size": query_params.page_size,
            "total_pages": total_pages,
            "page_items": len(experiments),
            "total_items": total_rows,
        },
    }
    return JSONResponse(response)  # TODO Use Validation. Not respond directly.
@router.get("/{expid}", name="Get experiment detail")
async def get_experiment_detail(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> BaseExperimentModel:
    """Get details of an experiment (base experiment table columns only)."""
    exp_builder = ExperimentBuilder()
    exp_builder.produce_base(expid)
    return exp_builder.product.model_dump(include=tables.experiment_table.c.keys())


@router.get("/{expid}/jobs", name="List experiment jobs")
async def get_experiment_jobs(
    expid: str,
    view: Annotated[Literal["quick", "base"], Query()] = "base",
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> ExperimentJobsResponse:
    """
    Get the experiment jobs from pickle file.
    BASE view returns base content of the pkl file.
    QUICK view returns a reduced payload with just the name and status of the jobs.
    """
    try:
        current_content = PklReader(expid).parse_job_list()
    except Exception as exc:
        error_message = "Error while reading the job list"
        logger.error(error_message + f": {exc}")
        # FIX: the original logged traceback.print_exc(), which returns None
        # (logging the literal string "None") and printed the traceback to
        # stderr instead; format_exc() returns the traceback text.
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=error_message
        )

    # Active jobs are prepended and finished/waiting jobs appended, so the
    # response lists active jobs first.
    pkl_jobs = deque()
    for job_item in current_content:
        resp_job = {
            "name": job_item.name,
            "status": Status.VALUE_TO_KEY.get(job_item.status, Status.UNKNOWN),
        }
        if view == "base":
            resp_job = {
                **resp_job,
                "priority": job_item.priority,
                "section": job_item.section,
                "date": (
                    job_item.date.date().isoformat()
                    if isinstance(job_item.date, datetime)
                    else None
                ),
                "member": job_item.member,
                "chunk": job_item.chunk,
                "out_path_local": job_item.out_path_local,
                "err_path_local": job_item.err_path_local,
                "out_path_remote": job_item.out_path_remote,
                "err_path_remote": job_item.err_path_remote,
            }

        if job_item.status in [Status.COMPLETED, Status.WAITING, Status.READY]:
            pkl_jobs.append(resp_job)
        else:
            pkl_jobs.appendleft(resp_job)

    return JSONResponse(
        {"jobs": list(pkl_jobs)}
    )  # TODO Use Validation. Not respond directly.


@router.get("/{expid}/wrappers", name="Get experiment wrappers")
async def get_experiment_wrappers(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> ExperimentWrappersResponse:
    """Get wrappers for an experiment (wrapper name -> contained job names)."""
    job_package_reader = JobPackageReader(expid)
    job_package_reader.read()

    wrappers_dict: Dict[str, List[str]] = job_package_reader.package_to_jobs

    wrappers = [
        {"wrapper_name": key, "job_names": val} for key, val in wrappers_dict.items()
    ]

    logger.debug(wrappers)
    return {"wrappers": wrappers}


def _format_config_response(
    config: Dict[str, Any], is_as3: bool = False
) -> Dict[str, Any]:
    """
    Format the config response, removing some keys if it's an AS3 config.
    Also, add a "contains_nones" key to indicate if the config is empty or
    has null values.

    :param config: The config to format
    :param is_as3: If the config is an AS3 config
    """
    ALLOWED_CONFIG_KEYS = ["conf", "exp", "jobs", "platforms", "proj"]
    formatted_config = {
        key: config[key]
        for key in config
        if not is_as3 or (key.lower() in ALLOWED_CONFIG_KEYS)
    }
    # FIX: dict.values() supports "in" directly; the intermediate list() was
    # an unnecessary copy.
    formatted_config["contains_nones"] = not config or (None in config.values())
    return formatted_config


@router.get(
    "/{expid}/filesystem-config", name="Get experiment current filesystem configuration"
)
async def get_experiment_fs_config(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> ExperimentFSConfigResponse:
    """Get the filesystem config of an experiment."""
    # Read the config from disk.
    APIBasicConfig.read()
    as_config = AutosubmitConfigResolver(expid, APIBasicConfig, ConfigParserFactory())
    # AS3 configs get their keys filtered down to the known sections.
    is_as3 = isinstance(as_config._configWrapper, confConfigStrategy)
    as_config.reload()
    curr_fs_config: Dict[str, Any] = as_config.get_full_config_as_dict()

    return {"config": _format_config_response(curr_fs_config, is_as3)}


@router.get("/{expid}/runs", name="List experiment runs")
async def get_runs(
    expid: str, user_id: Optional[str] = Depends(auth_token_dependency())
) -> ExperimentRunsResponse:
    """Get runs for a given experiment (run id plus UTC start/finish times)."""
    try:
        experiment_history = ExperimentHistoryDirector(
            ExperimentHistoryBuilder(expid)
        ).build_reader_experiment_history()
        exp_runs = experiment_history.get_experiment_runs()
    except Exception:
        logger.error("Error while getting experiment runs")
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Error while getting experiment runs",
        )

    def _ts_to_iso(timestamp: int) -> Optional[str]:
        # Non-positive timestamps mean "not started/finished yet".
        if timestamp > 0:
            return datetime.fromtimestamp(timestamp, timezone.utc).isoformat(
                timespec="seconds"
            )
        return None

    return {
        "runs": [
            {
                "run_id": run.run_id,
                "start": _ts_to_iso(run.start),
                "finish": _ts_to_iso(run.finish),
            }
            for run in exp_runs
        ]
    }


@router.get("/{expid}/runs/{run_id}/config", name="Get experiment run configuration")
async def get_run_config(
    expid: str,
    run_id: str,
    user_id: Optional[str] = Depends(auth_token_dependency()),
) -> ExperimentRunConfigResponse:
    """Get the config (stored as run metadata) of a specific run of an experiment."""
    historical_db = JobDataStructure(expid, APIBasicConfig)
    experiment_run = historical_db.get_experiment_run_by_id(run_id=run_id)
    metadata = (
        json.loads(experiment_run.metadata)
        if experiment_run and experiment_run.metadata
        else {}
    )

    return {
        "run_id": experiment_run.run_id if experiment_run else None,
        "config": _format_config_response(metadata),
    }
a/autosubmit_api/views/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from autosubmit_api.logger import with_log_run_times, logger -from autosubmit_api import __version__ as APIVersion -from http import HTTPStatus -from werkzeug.exceptions import HTTPException - - -def handle_HTTP_exception(e: HTTPException): - """Return JSON instead of HTML for HTTP errors.""" - return {"error": True, "error_message": e.name}, e.code - - -@with_log_run_times(logger, "HOME") -def home(): - return {"name": "Autosubmit API", "version": APIVersion}, HTTPStatus.OK diff --git a/autosubmit_api/views/v3.py b/autosubmit_api/views/v3.py deleted file mode 100644 index 49f5a645568368f9a434662d7fca6070a44e7ed1..0000000000000000000000000000000000000000 --- a/autosubmit_api/views/v3.py +++ /dev/null @@ -1,418 +0,0 @@ -from http import HTTPStatus -import os -from datetime import datetime, timedelta -from typing import Optional -import requests -from flask_cors import cross_origin -from flask import request, session, redirect -from autosubmit_api.auth import ProtectionLevels, with_auth_token -from autosubmit_api.database.db_common import ( - get_current_running_exp, - update_experiment_description_owner, -) -from autosubmit_api.experiment import common_requests as CommonRequests -from autosubmit_api.experiment import utils as Utiles -from autosubmit_api.logger import logger, with_log_run_times -from autosubmit_api.performance.performance_metrics import PerformanceMetrics -from autosubmit_api.database.db_common import search_experiment_by_id -from autosubmit_api.config.basicConfig import APIBasicConfig -from autosubmit_api.builders.joblist_helper_builder import ( - JobListHelperBuilder, - JobListHelperDirector, -) -from multiprocessing import Manager, Lock -import jwt -from autosubmit_api.config import ( - JWT_SECRET, - JWT_ALGORITHM, - JWT_EXP_DELTA_SECONDS, - CAS_LOGIN_URL, - CAS_VERIFY_URL, -) - - -# Multiprocessing setup -D = Manager().dict() -lock = Lock() - - -@with_log_run_times(logger, 
"LOGIN") -def login(): - APIBasicConfig.read() - ticket = request.args.get("ticket") - environment = request.args.get("env") - referrer = request.referrer - is_allowed = False - for allowed_client in APIBasicConfig.ALLOWED_CLIENTS: - if referrer and referrer.find(allowed_client) >= 0: - referrer = allowed_client - is_allowed = True - if is_allowed is False: - return { - "authenticated": False, - "user": None, - "token": None, - "message": "Your client is not authorized for this operation. The API admin needs to add your URL to the list of allowed clients.", - }, HTTPStatus.UNAUTHORIZED - - target_service = "{}{}/login".format(referrer, environment) - if not ticket: - route_to_request_ticket = "{}?service={}".format(CAS_LOGIN_URL, target_service) - logger.info("Redirected to: " + str(route_to_request_ticket)) - return redirect(route_to_request_ticket) - # can be used to target the test environment - # environment = environment if environment is not None else "autosubmitapp" - cas_verify_ticket_route = ( - CAS_VERIFY_URL + "?service=" + target_service + "&ticket=" + ticket - ) - response = requests.get(cas_verify_ticket_route) - user = None - if response: - user = Utiles.get_cas_user_from_xml(response.content) - logger.info("CAS verify ticket response: user %s", user) - if not user: - return { - "authenticated": False, - "user": None, - "token": None, - "message": "Can't verify user.", - }, HTTPStatus.UNAUTHORIZED - else: # Login successful - payload = { - "user_id": user, - "sub": user, - "iat": int(datetime.now().timestamp()), - "exp": (datetime.utcnow() + timedelta(seconds=JWT_EXP_DELTA_SECONDS)), - } - jwt_token = jwt.encode(payload, JWT_SECRET, JWT_ALGORITHM) - return { - "authenticated": True, - "user": user, - "token": jwt_token, - "message": "Token generated.", - } - - -@cross_origin(expose_headers="Authorization") -@with_log_run_times(logger, "TTEST") -@with_auth_token(threshold=ProtectionLevels.NONE, response_on_fail=False) -def test_token(user_id: 
Optional[str] = None): - """ - Tests if a token is still valid - """ - return { - "isValid": True if user_id else False, - "message": "Unauthorized" if not user_id else None, - }, HTTPStatus.OK if user_id else HTTPStatus.UNAUTHORIZED - - -@cross_origin(expose_headers="Authorization") -@with_log_run_times(logger, "UDESC") -@with_auth_token(threshold=ProtectionLevels.WRITEONLY) -def update_description(user_id: Optional[str] = None): - """ - Updates the description of an experiment. Requires authenticated user. - """ - expid = None - new_description = None - if request.is_json: - body_data = request.json - expid = body_data.get("expid", None) - new_description = body_data.get("description", None) - return ( - update_experiment_description_owner(expid, new_description, user_id), - HTTPStatus.OK if user_id else HTTPStatus.UNAUTHORIZED, - ) - - -@cross_origin(expose_headers="Authorization") -@with_log_run_times(logger, "CCONFIG") -@with_auth_token() -def get_current_configuration(expid: str, user_id: Optional[str] = None): - result = CommonRequests.get_current_configuration_by_expid(expid, user_id) - return result - - -@with_log_run_times(logger, "EXPINFO") -@with_auth_token() -def exp_info(expid: str, user_id: Optional[str] = None): - result = CommonRequests.get_experiment_data(expid) - return result - - -@with_log_run_times(logger, "EXPCOUNT") -@with_auth_token() -def exp_counters(expid: str, user_id: Optional[str] = None): - result = CommonRequests.get_experiment_counters(expid) - return result - - -@with_log_run_times(logger, "SOWNER") -@with_auth_token() -def search_owner(owner, exptype=None, onlyactive=None, user_id: Optional[str] = None): - """ - Same output format as search_expid - """ - result = search_experiment_by_id( - query=None, owner=owner, exp_type=exptype, only_active=onlyactive - ) - return result - - -@with_log_run_times(logger, "SEARCH") -@with_auth_token() -def search_expid(expid, exptype=None, onlyactive=None, user_id: Optional[str] = None): - 
result = search_experiment_by_id( - query=expid, owner=None, exp_type=exptype, only_active=onlyactive - ) - return result - - -@with_log_run_times(logger, "RUN") -@with_auth_token() -def search_running(user_id: Optional[str] = None): - """ - Returns the list of all experiments that are currently running. - """ - if "username" in session: - logger.debug(("USER {}".format(session["username"]))) - logger.debug("Active proceses: " + str(D)) - # logger.info("Received Currently Running query ") - result = get_current_running_exp() - return result - - -@with_log_run_times(logger, "ERUNS") -@with_auth_token() -def get_runs(expid, user_id: Optional[str] = None): - """ - Get list of runs of the same experiment from the historical db - """ - result = CommonRequests.get_experiment_runs(expid) - return result - - -@with_log_run_times(logger, "IFRUN") -@with_auth_token() -def get_if_running(expid, user_id: Optional[str] = None): - result = CommonRequests.quick_test_run(expid) - return result - - -@with_log_run_times(logger, "RUNDET") -@with_auth_token() -def get_running_detail(expid, user_id: Optional[str] = None): - result = CommonRequests.get_current_status_log_plus(expid) - return result - - -@with_log_run_times(logger, "SUMMARY") -@with_auth_token() -def get_expsummary(expid, user_id: Optional[str] = None): - user = request.args.get("loggedUser", default="null", type=str) - if user != "null": - lock.acquire() - D[os.getpid()] = [user, "summary", True] - lock.release() - result = CommonRequests.get_experiment_summary(expid, logger) - logger.info("Process: " + str(os.getpid()) + " workers: " + str(D)) - if user != "null": - lock.acquire() - D[os.getpid()] = [user, "summary", False] - lock.release() - if user != "null": - lock.acquire() - D.pop(os.getpid(), None) - lock.release() - return result - - -@with_log_run_times(logger, "SHUTDOWN") -@with_auth_token() -def shutdown(route, user_id: Optional[str] = None): - """ - This function is invoked from the frontend (AS-GUI) to kill 
workers that are no longer needed. - This call is common in heavy parts of the GUI such as the Tree and Graph generation or Summaries fetching. - """ - try: - user = request.args.get("loggedUser", default="null", type=str) - expid = request.args.get("expid", default="null", type=str) - except Exception: - logger.info("Bad parameters for user and expid in route.") - - if user != "null": - logger.info( - "SHUTDOWN|DETAILS|route: " + route + " user: " + user + " expid: " + expid - ) - try: - # logger.info("user: " + user) - # logger.info("expid: " + expid) - logger.info("Workers before: " + str(D)) - lock.acquire() - for k, v in list(D.items()): - if v[0] == user and v[1] == route and v[-1] is True: - if v[2] == expid: - D[k] = [user, route, expid, False] - else: - D[k] = [user, route, False] - D.pop(k, None) - # reboot the worker - os.system("kill -HUP " + str(k)) - logger.info("killed worker " + str(k)) - lock.release() - logger.info("Workers now: " + str(D)) - except Exception: - logger.info( - "[CRITICAL] Could not shutdown process " - + expid - + ' by user "' - + user - + '"' - ) - return "" - - -@with_log_run_times(logger, "PRF") -@with_auth_token() -def get_exp_performance(expid, user_id: Optional[str] = None): - result = {} - try: - result = PerformanceMetrics( - expid, - JobListHelperDirector(JobListHelperBuilder(expid)).build_job_list_helper(), - ).to_json() - except Exception as exc: - result = { - "SYPD": None, - "ASYPD": None, - "RSYPD": None, - "CHSY": None, - "JPSY": None, - "Parallelization": None, - "PE": None, - "considered": [], - "error": True, - "error_message": str(exc), - "warnings_job_data": [], - } - return result - - -@with_log_run_times(logger, "GRAPH") -@with_auth_token() -def get_graph_format( - expid, layout="standard", grouped="none", user_id: Optional[str] = None -): - user = request.args.get("loggedUser", default="null", type=str) - # logger.info("user: " + user) - # logger.info("expid: " + expid) - if user != "null": - lock.acquire() 
- D[os.getpid()] = [user, "graph", expid, True] - lock.release() - result = CommonRequests.get_experiment_graph(expid, logger, layout, grouped) - logger.info("Process: " + str(os.getpid()) + " graph workers: " + str(D)) - if user != "null": - lock.acquire() - D[os.getpid()] = [user, "graph", expid, False] - lock.release() - if user != "null": - lock.acquire() - D.pop(os.getpid(), None) - lock.release() - return result - - -@with_log_run_times(logger, "TREE") -@with_auth_token() -def get_exp_tree(expid, user_id: Optional[str] = None): - user = request.args.get("loggedUser", default="null", type=str) - # logger.info("user: " + user) - # logger.info("expid: " + expid) - if user != "null": - lock.acquire() - D[os.getpid()] = [user, "tree", expid, True] - lock.release() - result = CommonRequests.get_experiment_tree_structured(expid, logger) - logger.info("Process: " + str(os.getpid()) + " tree workers: " + str(D)) - if user != "null": - lock.acquire() - D[os.getpid()] = [user, "tree", expid, False] - lock.release() - if user != "null": - lock.acquire() - D.pop(os.getpid(), None) - lock.release() - return result - - -@with_log_run_times(logger, "QUICK") -@with_auth_token(response_on_fail=True) -def get_quick_view_data(expid, user_id=None): - result = CommonRequests.get_quick_view(expid) - return result - - -@with_log_run_times(logger, "LOG") -@with_auth_token() -def get_experiment_run_log(expid, user_id: Optional[str] = None): - """ - Finds log and gets the last 150 lines - """ - result = CommonRequests.get_experiment_log_last_lines(expid) - return result - - -@with_log_run_times(logger, "JOBLOG") -@with_auth_token() -def get_job_log_from_path(logfile, user_id: Optional[str] = None): - """ - Get log from path - """ - expid = logfile.split("_") if logfile is not None else "" - expid = expid[0] if len(expid) > 0 else "" - result = CommonRequests.get_job_log(expid, logfile) - return result - - -@with_log_run_times(logger, "GPKL") -@with_auth_token() -def 
get_experiment_pklinfo(expid, timeStamp=None, user_id: Optional[str] = None): - result = CommonRequests.get_experiment_pkl(expid) - return result - - -@with_log_run_times(logger, "TPKL") -@with_auth_token() -def get_experiment_tree_pklinfo(expid, timeStamp=None, user_id: Optional[str] = None): - result = CommonRequests.get_experiment_tree_pkl(expid) - return result - - -@with_log_run_times(logger, "STAT") -@with_auth_token() -def get_experiment_statistics( - expid, filter_period, filter_type, user_id: Optional[str] = None -): - result = CommonRequests.get_experiment_stats(expid, filter_period, filter_type) - return result - - -@with_log_run_times(logger, "HISTORY") -@with_auth_token() -def get_exp_job_history(expid, jobname, user_id: Optional[str] = None): - result = CommonRequests.get_job_history(expid, jobname) - return result - - -@with_log_run_times(logger, "RUNDETAIL") -@with_auth_token() -def get_experiment_run_job_detail(expid, runid, user_id: Optional[str] = None): - result = CommonRequests.get_experiment_tree_rundetail(expid, runid) - return result - - -@with_log_run_times(logger, "FSTATUS") -def get_file_status(): - result = CommonRequests.get_last_test_archive_status() - return result diff --git a/autosubmit_api/views/v4.py b/autosubmit_api/views/v4.py deleted file mode 100644 index 138364bec482e092938c2a4d5b5ce22e424c5940..0000000000000000000000000000000000000000 --- a/autosubmit_api/views/v4.py +++ /dev/null @@ -1,533 +0,0 @@ -from collections import deque -from datetime import datetime, timedelta, timezone -from enum import Enum -from http import HTTPStatus -import json -import math -import traceback -from typing import Any, Dict, List, Optional -from flask import redirect, request -from flask.views import MethodView -import jwt -import requests -from autosubmit_api.auth import ProtectionLevels, with_auth_token -from autosubmit_api.auth.utils import validate_client -from autosubmit_api.builders.experiment_builder import ExperimentBuilder -from 
autosubmit_api.builders.experiment_history_builder import ( - ExperimentHistoryBuilder, - ExperimentHistoryDirector, -) -from autosubmit_api.common.utils import Status -from autosubmit_api.config.basicConfig import APIBasicConfig -from autosubmit_api.config.confConfigStrategy import confConfigStrategy -from autosubmit_api.config.config_common import AutosubmitConfigResolver -from autosubmit_api.database import tables -from autosubmit_api.database.common import ( - create_main_db_conn, - execute_with_limit_offset, -) -from autosubmit_api.database.db_jobdata import JobDataStructure -from autosubmit_api.database.queries import generate_query_listexp_extended -from autosubmit_api.logger import logger, with_log_run_times -from cas import CASClient -from autosubmit_api import config -from autosubmit_api.persistance.job_package_reader import JobPackageReader -from autosubmit_api.persistance.pkl_reader import PklReader -from bscearth.utils.config_parser import ConfigParserFactory - - -PAGINATION_LIMIT_DEFAULT = 12 - - -class CASV2Login(MethodView): - decorators = [with_log_run_times(logger, "CASV2LOGIN")] - - def get(self): - ticket = request.args.get("ticket") - service = request.args.get("service", request.base_url) - - is_allowed_service = (service == request.base_url) or validate_client(service) - - if not is_allowed_service: - return { - "authenticated": False, - "user": None, - "token": None, - "message": "Your service is not authorized for this operation. The API admin needs to add your URL to the list of allowed clients.", - }, HTTPStatus.UNAUTHORIZED - - cas_client = CASClient( - version=2, service_url=service, server_url=config.CAS_SERVER_URL - ) - - if not ticket: - # No ticket, the request come from end user, send to CAS login - cas_login_url = cas_client.get_login_url() - return redirect(cas_login_url) - - # There is a ticket, the request come from CAS as callback. - # need call `verify_ticket()` to validate ticket and get user profile. 
- user, attributes, pgtiou = cas_client.verify_ticket(ticket) - - if not user: - return { - "authenticated": False, - "user": None, - "token": None, - "message": "Can't verify user", - }, HTTPStatus.UNAUTHORIZED - else: # Login successful - payload = { - "user_id": user, - "sub": user, - "iat": int(datetime.now().timestamp()), - "exp": ( - datetime.now() + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) - ), - } - jwt_token = jwt.encode(payload, config.JWT_SECRET, config.JWT_ALGORITHM) - return { - "authenticated": True, - "user": user, - "token": jwt_token, - "message": "Token generated", - }, HTTPStatus.OK - - -class GithubOauth2Login(MethodView): - decorators = [with_log_run_times(logger, "GHOAUTH2LOGIN")] - - def get(self): - """ - Authenticate and authorize user using a cofigured GitHub Oauth app. - The authorization in done by verifying users membership to either a Github Team - or Organization. - """ - - code = request.args.get("code") - - if not code: - return { - "authenticated": False, - "user": None, - "token": None, - "message": "Can't verify user", - }, HTTPStatus.UNAUTHORIZED - - resp_obj: dict = requests.post( - "https://github.com/login/oauth/access_token", - data={ - "client_id": config.GITHUB_OAUTH_CLIENT_ID, - "client_secret": config.GITHUB_OAUTH_CLIENT_SECRET, - "code": code, - }, - headers={"Accept": "application/json"}, - ).json() - access_token = resp_obj.get("access_token") - - user_info: dict = requests.get( - "https://api.github.com/user", - headers={"Authorization": f"Bearer {access_token}"}, - ).json() - username = user_info.get("login") - - if not username: - return { - "authenticated": False, - "user": None, - "token": None, - "message": "Couldn't find user on GitHub", - }, HTTPStatus.UNAUTHORIZED - - # Whitelist organization team - if ( - config.GITHUB_OAUTH_WHITELIST_ORGANIZATION - and config.GITHUB_OAUTH_WHITELIST_TEAM - ): - org_resp = requests.get( - 
f"https://api.github.com/orgs/{config.GITHUB_OAUTH_WHITELIST_ORGANIZATION}/teams/{config.GITHUB_OAUTH_WHITELIST_TEAM}/memberships/{username}", - headers={"Authorization": f"Bearer {access_token}"}, - ) - membership: dict = org_resp.json() - is_member = ( - org_resp.status_code == 200 and membership.get("state") == "active" - ) # https://docs.github.com/en/rest/teams/members?apiVersion=2022-11-28#get-team-membership-for-a-user - elif ( - config.GITHUB_OAUTH_WHITELIST_ORGANIZATION - ): # Whitelist all organization (no team) - org_resp = requests.get( - f"https://api.github.com/orgs/{config.GITHUB_OAUTH_WHITELIST_ORGANIZATION}/members/{username}", - headers={"Authorization": f"Bearer {access_token}"}, - ) - is_member = ( - org_resp.status_code == 204 - ) # https://docs.github.com/en/rest/orgs/members?apiVersion=2022-11-28#check-organization-membership-for-a-user - else: # No authorization check - is_member = True - - # Login successful - if is_member: - payload = { - "user_id": username, - "sub": username, - "iat": int(datetime.now().timestamp()), - "exp": ( - datetime.now() + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) - ), - } - jwt_token = jwt.encode(payload, config.JWT_SECRET, config.JWT_ALGORITHM) - return { - "authenticated": True, - "user": username, - "token": jwt_token, - "message": "Token generated", - }, HTTPStatus.OK - else: # UNAUTHORIZED - return { - "authenticated": False, - "user": None, - "token": None, - "message": "User is not member of organization or team", - }, HTTPStatus.UNAUTHORIZED - - -class AuthJWTVerify(MethodView): - decorators = [ - with_auth_token(threshold=ProtectionLevels.NONE, response_on_fail=False), - with_log_run_times(logger, "JWTVRF"), - ] - - def get(self, user_id: Optional[str] = None): - """ - Verify JWT endpoint. 
- """ - return { - "authenticated": True if user_id else False, - "user": user_id, - }, (HTTPStatus.OK if user_id else HTTPStatus.UNAUTHORIZED) - - -class ExperimentView(MethodView): - # IMPORTANT: Remember that in MethodView last decorator is executed first - decorators = [with_auth_token(), with_log_run_times(logger, "SEARCH4")] - - def get(self, user_id: Optional[str] = None): - """ - Search experiments view targeted to handle args - """ - # Parse args - logger.debug("Search args: " + str(request.args)) - - query = request.args.get("query") - only_active = request.args.get("only_active") == "true" - owner = request.args.get("owner") - exp_type = request.args.get("exp_type") - autosubmit_version = request.args.get("autosubmit_version") - - order_by = request.args.get("order_by") - order_desc = request.args.get("order_desc") == "true" - - try: - page = max(request.args.get("page", default=1, type=int), 1) - page_size = request.args.get( - "page_size", default=PAGINATION_LIMIT_DEFAULT, type=int - ) - if page_size > 0: - offset = (page - 1) * page_size - else: - page_size = None - offset = None - except Exception: - return {"error": {"message": "Invalid params"}}, HTTPStatus.BAD_REQUEST - - # Query - statement = generate_query_listexp_extended( - query=query, - only_active=only_active, - owner=owner, - exp_type=exp_type, - autosubmit_version=autosubmit_version, - order_by=order_by, - order_desc=order_desc, - ) - with create_main_db_conn() as conn: - query_result, total_rows = execute_with_limit_offset( - statement=statement, - conn=conn, - limit=page_size, - offset=offset, - ) - - # Process experiments - experiments = [] - for raw_exp in query_result: - exp_builder = ExperimentBuilder() - exp_builder.produce_base_from_dict(raw_exp._mapping) - exp_builder.produce_pkl_modified_time() - exp = exp_builder.product - - # Get current run data from history - # last_modified_timestamp = exp.created - completed = 0 - total = 0 - submitted = 0 - queuing = 0 - running = 0 - 
failed = 0 - suspended = 0 - try: - current_run = ( - ExperimentHistoryDirector(ExperimentHistoryBuilder(exp.name)) - .build_reader_experiment_history() - .manager.get_experiment_run_dc_with_max_id() - ) - if current_run and current_run.total > 0: - completed = current_run.completed - total = current_run.total - submitted = current_run.submitted - queuing = current_run.queuing - running = current_run.running - failed = current_run.failed - suspended = current_run.suspended - # last_modified_timestamp = current_run.modified_timestamp - except Exception as exc: - logger.warning((f"Exception getting the current run on search: {exc}")) - logger.warning(traceback.format_exc()) - - # Format data - experiments.append( - { - "id": exp.id, - "name": exp.name, - "user": exp.user, - "description": exp.description, - "hpc": exp.hpc, - "version": exp.autosubmit_version, - # "wrapper": exp.wrapper, - "created": exp.created, - "modified": exp.modified, - "status": exp.status if exp.status else "NOT RUNNING", - "completed": completed, - "total": total, - "submitted": submitted, - "queuing": queuing, - "running": running, - "failed": failed, - "suspended": suspended, - } - ) - - # Response - response = { - "experiments": experiments, - "pagination": { - "page": page, - "page_size": page_size, - "total_pages": math.ceil(total_rows / page_size) if page_size else 1, - "page_items": len(experiments), - "total_items": total_rows, - }, - } - return response - - -class ExperimentDetailView(MethodView): - decorators = [with_auth_token(), with_log_run_times(logger, "EXPDETAIL")] - - def get(self, expid: str, user_id: Optional[str] = None): - """ - Get details of an experiment - """ - exp_builder = ExperimentBuilder() - exp_builder.produce_base(expid) - return exp_builder.product.model_dump(include=tables.experiment_table.c.keys()) - - -class ExperimentJobsViewOptEnum(str, Enum): - QUICK = "quick" - BASE = "base" - - -class ExperimentJobsView(MethodView): - decorators = [with_auth_token(), 
with_log_run_times(logger, "EXPJOBS")] - - def get(self, expid: str, user_id: Optional[str] = None): - """ - Get the experiment jobs from pickle file. - BASE view returns base content of the pkl file. - QUICK view returns a reduced payload with just the name and status of the jobs. - """ - view = request.args.get( - "view", type=str, default=ExperimentJobsViewOptEnum.BASE - ) - - # Read the pkl - try: - current_content = PklReader(expid).parse_job_list() - except Exception as exc: - error_message = "Error while reading the job list" - logger.error(error_message + f": {exc}") - logger.error(traceback.print_exc()) - return { - "error": {"message": error_message} - }, HTTPStatus.INTERNAL_SERVER_ERROR - - pkl_jobs = deque() - for job_item in current_content: - resp_job = { - "name": job_item.name, - "status": Status.VALUE_TO_KEY.get(job_item.status, Status.UNKNOWN), - } - - if view == ExperimentJobsViewOptEnum.BASE: - resp_job = { - **resp_job, - "priority": job_item.priority, - "section": job_item.section, - "date": ( - job_item.date.date().isoformat() - if isinstance(job_item.date, datetime) - else None - ), - "member": job_item.member, - "chunk": job_item.chunk, - "out_path_local": job_item.out_path_local, - "err_path_local": job_item.err_path_local, - "out_path_remote": job_item.out_path_remote, - "err_path_remote": job_item.err_path_remote, - } - - if job_item.status in [Status.COMPLETED, Status.WAITING, Status.READY]: - pkl_jobs.append(resp_job) - else: - pkl_jobs.appendleft(resp_job) - - return {"jobs": list(pkl_jobs)}, HTTPStatus.OK - - -class ExperimentWrappersView(MethodView): - decorators = [with_auth_token(), with_log_run_times(logger, "WRAPPERS")] - - def get(self, expid: str, user_id: Optional[str] = None): - job_package_reader = JobPackageReader(expid) - job_package_reader.read() - - wrappers_dict: Dict[str, List[str]] = job_package_reader.package_to_jobs - - wrappers = [] - for key, val in wrappers_dict.items(): - wrappers.append({"wrapper_name": key, 
"job_names": val}) - - logger.debug(wrappers) - return {"wrappers": wrappers} - - -class ExperimentFSConfigView(MethodView): - decorators = [with_auth_token(), with_log_run_times(logger, "EXP_FS_CONFIG")] - - @staticmethod - def _format_config_response( - config: Dict[str, Any], is_as3: bool = False - ) -> Dict[str, Any]: - """ - Format the config response, removing some keys if it's an AS3 config - Also, add a key to indicate if the config is empty - :param config: The config to format - :param is_as3: If the config is an AS3 config - """ - ALLOWED_CONFIG_KEYS = ["conf", "exp", "jobs", "platforms", "proj"] - formatted_config = { - key: config[key] - for key in config - if not is_as3 or (key.lower() in ALLOWED_CONFIG_KEYS) - } - formatted_config["contains_nones"] = not config or ( - None in list(config.values()) - ) - return formatted_config - - def get(self, expid: str, user_id: Optional[str] = None): - """ - Get the filesystem config of an experiment - """ - # Read the config - APIBasicConfig.read() - as_config = AutosubmitConfigResolver( - expid, APIBasicConfig, ConfigParserFactory() - ) - is_as3 = isinstance(as_config._configWrapper, confConfigStrategy) - as_config.reload() - curr_fs_config: Dict[str, Any] = as_config.get_full_config_as_dict() - - # Format the response - response = { - "config": ExperimentFSConfigView._format_config_response( - curr_fs_config, is_as3 - ) - } - return response, HTTPStatus.OK - - -class ExperimentRunsView(MethodView): - decorators = [with_auth_token(), with_log_run_times(logger, "EXP_RUNS")] - - def get(self, expid: str, user_id: Optional[str] = None): - """ - List all the runs of an experiment - It returns minimal information about the runs - """ - try: - experiment_history = ExperimentHistoryDirector( - ExperimentHistoryBuilder(expid) - ).build_reader_experiment_history() - exp_runs = experiment_history.get_experiment_runs() - except Exception: - logger.error("Error while getting experiment runs") - 
logger.error(traceback.format_exc()) - return { - "message": "Error while getting experiment runs" - }, HTTPStatus.INTERNAL_SERVER_ERROR - - # Format the response - response = {"runs": []} - for run in exp_runs: - response["runs"].append( - { - "run_id": run.run_id, - "start": datetime.fromtimestamp(run.start, timezone.utc).isoformat( - timespec="seconds" - ) - if run.start > 0 - else None, - "finish": datetime.fromtimestamp( - run.finish, timezone.utc - ).isoformat(timespec="seconds") - if run.finish > 0 - else None, - } - ) - - return response, HTTPStatus.OK - - -class ExperimentRunConfigView(MethodView): - decorators = [with_auth_token(), with_log_run_times(logger, "EXP_RUN_CONFIG")] - - def get(self, expid: str, run_id: str, user_id: Optional[str] = None): - """ - Get the config of a specific run of an experiment - """ - historical_db = JobDataStructure(expid, APIBasicConfig) - experiment_run = historical_db.get_experiment_run_by_id(run_id=run_id) - metadata = ( - json.loads(experiment_run.metadata) - if experiment_run and experiment_run.metadata - else {} - ) - - # Format the response - response = { - "run_id": experiment_run.run_id if experiment_run else None, - "config": ExperimentFSConfigView._format_config_response(metadata), - } - return response, HTTPStatus.OK diff --git a/setup.py b/setup.py index 8d83ad370c5beb3a86fb59e3b9c8ae291f114974..5d0ed5cad0d652370c2e2940c10dcb6dd0d812e6 100644 --- a/setup.py +++ b/setup.py @@ -20,10 +20,9 @@ def get_authors(): install_requires = [ - "Flask~=2.2.5", + "fastapi[all]==0.115.0", "pyjwt~=2.8.0", "requests~=2.28.1", - "flask_cors~=3.0.10", "bscearth.utils~=0.5.2", "pydotplus~=2.0.2", "portalocker~=2.6.0", @@ -32,7 +31,7 @@ install_requires = [ "python-dotenv~=1.0.1", "autosubmitconfigparser>=1.0.65", "autosubmit>=3.13", - "Flask-APScheduler~=1.13.1", + "APScheduler~=3.10.4", "gunicorn~=22.0.0", "pydantic~=2.5.2", "SQLAlchemy~=2.0.23", @@ -43,6 +42,7 @@ install_requires = [ test_requires = [ "pytest", "pytest-cov", + 
"pytest-asyncio", "ruff" ] diff --git a/tests/conftest.py b/tests/conftest.py index 248663bdc72e6ed97e5d099190ce017007a782aa..ea829eb31db2f8a2f0061f65b3d9876fc075f109 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,8 @@ import os import tempfile -from flask import Flask import pytest -from autosubmit_api.app import create_app +from fastapi.testclient import TestClient from autosubmit_api.config.basicConfig import APIBasicConfig from autosubmit_api import config @@ -38,24 +37,11 @@ def fixture_mock_basic_config(request: pytest.FixtureRequest): @pytest.fixture -def fixture_app(fixture_mock_basic_config): - app = create_app() - app.config.update( - { - "TESTING": True, - } - ) - yield app - - -@pytest.fixture -def fixture_client(fixture_app: Flask): - return fixture_app.test_client() - +def fixture_fastapi_client(fixture_mock_basic_config): + from autosubmit_api import app -@pytest.fixture -def fixture_runner(fixture_app: Flask): - return fixture_app.test_cli_runner() + with TestClient(app.app) as client: + yield client # Fixtures sqlite diff --git a/tests/test_auth.py b/tests/test_auth.py index 51f296d730a7d3b3df4bde8e0314810f87c4b639..d7b96e913f58da1d10bc4b3af32fae8bc83ea238 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,12 +1,14 @@ import os from uuid import uuid4 +from fastapi import HTTPException +from fastapi.security import HTTPAuthorizationCredentials import pytest -from autosubmit_api.auth import ProtectionLevels, with_auth_token +from autosubmit_api.auth import ProtectionLevels, auth_token_dependency from autosubmit_api import auth from autosubmit_api.auth.utils import validate_client from autosubmit_api.config.basicConfig import APIBasicConfig from autosubmit_api import config -from tests.utils import custom_return_value, dummy_response +from tests.utils import custom_return_value class TestCommonAuth: @@ -18,13 +20,19 @@ class TestCommonAuth: assert ProtectionLevels.ALL > ProtectionLevels.WRITEONLY assert 
ProtectionLevels.WRITEONLY > ProtectionLevels.NONE - def test_decorator(self, monkeypatch: pytest.MonkeyPatch): + @pytest.mark.asyncio + async def test_dependency(self, monkeypatch: pytest.MonkeyPatch): """ Test different authorization levels. Setting an AUTHORIZATION_LEVEL=ALL will protect all routes no matter it's protection level. If a route is set with level = NONE, will be always protected. """ + # Invalid credentials + credentials = HTTPAuthorizationCredentials( + scheme="Bearer", credentials="invalid_token" + ) + # Test on AuthorizationLevels.ALL monkeypatch.setattr( auth, @@ -32,16 +40,16 @@ class TestCommonAuth: custom_return_value(ProtectionLevels.ALL), ) - _, code = with_auth_token(threshold=ProtectionLevels.ALL)(dummy_response)() - assert code == 401 + with pytest.raises(HTTPException): + await auth_token_dependency(threshold=ProtectionLevels.ALL)(credentials) - _, code = with_auth_token(threshold=ProtectionLevels.WRITEONLY)( - dummy_response - )() - assert code == 401 + with pytest.raises(HTTPException): + await auth_token_dependency(threshold=ProtectionLevels.WRITEONLY)( + credentials + ) - _, code = with_auth_token(threshold=ProtectionLevels.NONE)(dummy_response)() - assert code == 401 + with pytest.raises(HTTPException): + await auth_token_dependency(threshold=ProtectionLevels.NONE)(credentials) # Test on AuthorizationLevels.WRITEONLY monkeypatch.setattr( @@ -50,16 +58,18 @@ class TestCommonAuth: custom_return_value(ProtectionLevels.WRITEONLY), ) - _, code = with_auth_token(threshold=ProtectionLevels.ALL)(dummy_response)() - assert code == 200 + assert ( + await auth_token_dependency(threshold=ProtectionLevels.ALL)(credentials) + is None + ) - _, code = with_auth_token(threshold=ProtectionLevels.WRITEONLY)( - dummy_response - )() - assert code == 401 + with pytest.raises(HTTPException): + await auth_token_dependency(threshold=ProtectionLevels.WRITEONLY)( + credentials + ) - _, code = with_auth_token(threshold=ProtectionLevels.NONE)(dummy_response)() 
- assert code == 401 + with pytest.raises(HTTPException): + await auth_token_dependency(threshold=ProtectionLevels.NONE)(credentials) # Test on AuthorizationLevels.NONE monkeypatch.setattr( @@ -68,16 +78,20 @@ class TestCommonAuth: custom_return_value(ProtectionLevels.NONE), ) - _, code = with_auth_token(threshold=ProtectionLevels.ALL)(dummy_response)() - assert code == 200 + assert ( + await auth_token_dependency(threshold=ProtectionLevels.ALL)(credentials) + is None + ) - _, code = with_auth_token(threshold=ProtectionLevels.WRITEONLY)( - dummy_response - )() - assert code == 200 + assert ( + await auth_token_dependency(threshold=ProtectionLevels.WRITEONLY)( + credentials + ) + is None + ) - _, code = with_auth_token(threshold=ProtectionLevels.NONE)(dummy_response)() - assert code == 401 + with pytest.raises(HTTPException): + await auth_token_dependency(threshold=ProtectionLevels.NONE)(credentials) def test_validate_client( self, monkeypatch: pytest.MonkeyPatch, fixture_mock_basic_config diff --git a/tests/test_endpoints_v3.py b/tests/test_endpoints_v3.py index dd159df0211202d644fdda878f55f089266b1c80..6ab17e82ae20c199b7199ae68e8d10a07923b26c 100644 --- a/tests/test_endpoints_v3.py +++ b/tests/test_endpoints_v3.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from http import HTTPStatus from uuid import uuid4 -from flask.testing import FlaskClient +from fastapi.testclient import TestClient import jwt from autosubmit_api import config import pytest @@ -13,70 +13,70 @@ class TestLogin: def test_not_allowed_client( self, - fixture_client: FlaskClient, + fixture_fastapi_client: TestClient, fixture_mock_basic_config: APIBasicConfig, monkeypatch: pytest.MonkeyPatch, ): monkeypatch.setattr(APIBasicConfig, "ALLOWED_CLIENTS", []) - response = fixture_client.get(self.endpoint) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint) + resp_obj: dict = response.json() assert 
response.status_code == HTTPStatus.UNAUTHORIZED assert resp_obj.get("authenticated") is False def test_redirect( self, - fixture_client: FlaskClient, + fixture_fastapi_client: TestClient, fixture_mock_basic_config: APIBasicConfig, monkeypatch: pytest.MonkeyPatch, ): random_referer = str(f"https://${str(uuid4())}/") monkeypatch.setattr(APIBasicConfig, "ALLOWED_CLIENTS", [random_referer]) - response = fixture_client.get( - self.endpoint, headers={"Referer": random_referer} + response = fixture_fastapi_client.get( + self.endpoint, headers={"Referer": random_referer}, follow_redirects=False ) - assert response.status_code == HTTPStatus.FOUND - assert config.CAS_LOGIN_URL in response.location - assert random_referer in response.location + assert response.status_code in [HTTPStatus.FOUND, HTTPStatus.TEMPORARY_REDIRECT] + assert config.CAS_LOGIN_URL in response.headers["Location"] + assert random_referer in response.headers["Location"] class TestVerifyToken: endpoint = "/v3/tokentest" - def test_unauthorized_no_token(self, fixture_client: FlaskClient): - response = fixture_client.get(self.endpoint) - resp_obj: dict = response.get_json() + def test_unauthorized_no_token(self, fixture_fastapi_client: TestClient): + response = fixture_fastapi_client.get(self.endpoint) + resp_obj: dict = response.json() assert response.status_code == HTTPStatus.UNAUTHORIZED assert resp_obj.get("isValid") is False - def test_unauthorized_random_token(self, fixture_client: FlaskClient): + def test_unauthorized_random_token(self, fixture_fastapi_client: TestClient): random_token = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint, headers={"Authorization": random_token} ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert response.status_code == HTTPStatus.UNAUTHORIZED assert resp_obj.get("isValid") is False - def test_authorized(self, fixture_client: FlaskClient): + def test_authorized(self, fixture_fastapi_client: 
TestClient): random_user = str(uuid4()) payload = { "user_id": random_user, "exp": ( - datetime.utcnow() + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) + datetime.now(timezone.utc) + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) ), } jwt_token = jwt.encode(payload, config.JWT_SECRET, config.JWT_ALGORITHM) - response = fixture_client.get( - self.endpoint, headers={"Authorization": jwt_token} + response = fixture_fastapi_client.get( + self.endpoint, headers={"Authorization": f"Bearer {jwt_token}"} ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert response.status_code == HTTPStatus.OK assert resp_obj.get("isValid") is True @@ -85,19 +85,20 @@ class TestVerifyToken: class TestExpInfo: endpoint = "/v3/expinfo/{expid}" - def test_info(self, fixture_client: FlaskClient): + def test_info(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() + assert response.status_code == HTTPStatus.OK assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["expid"] == expid assert resp_obj["total_jobs"] == 8 - def test_retro3_info(self, fixture_client: FlaskClient): + def test_retro3_info(self, fixture_fastapi_client: TestClient): expid = "a3tb" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["expid"] == expid @@ -108,13 +109,13 @@ class TestExpInfo: class TestPerformance: endpoint = "/v3/performance/{expid}" - def test_parallelization(self, fixture_client: FlaskClient): + def test_parallelization(self, fixture_fastapi_client: TestClient): """ Test 
parallelization without PROCESSORS_PER_NODE """ expid = "a007" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["Parallelization"] == 8 @@ -123,8 +124,8 @@ class TestPerformance: ) expid = "a3tb" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["Parallelization"] == 768 @@ -132,13 +133,13 @@ class TestPerformance: resp_obj["not_considered"], list ) - def test_parallelization_platforms(self, fixture_client: FlaskClient): + def test_parallelization_platforms(self, fixture_fastapi_client: TestClient): """ Test parallelization that comes from default platform """ expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["Parallelization"] == 16 @@ -150,14 +151,14 @@ class TestPerformance: class TestTree: endpoint = "/v3/tree/{expid}" - def test_tree(self, fixture_client: FlaskClient): + def test_tree(self, fixture_fastapi_client: TestClient): expid = "a003" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] 
is False @@ -166,14 +167,14 @@ class TestTree: for job in resp_obj["jobs"]: assert job["id"][:4] == expid - def test_retro3(self, fixture_client: FlaskClient): + def test_retro3(self, fixture_fastapi_client: TestClient): expid = "a3tb" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -185,10 +186,10 @@ class TestTree: for job in resp_obj["jobs"]: assert job["id"][:4] == expid - def test_wrappers(self, fixture_client: FlaskClient): + def test_wrappers(self, fixture_fastapi_client: TestClient): expid = "a6zj" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert len(resp_obj["jobs"]) == 10 @@ -206,11 +207,11 @@ class TestTree: class TestRunsList: endpoint = "/v3/runs/{expid}" - def test_runs_list(self, fixture_client: FlaskClient): + def test_runs_list(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -220,11 +221,13 @@ class TestRunsList: class TestRunDetail: endpoint = "/v3/rundetail/{expid}/{runId}" - def test_runs_detail(self, fixture_client: FlaskClient): + def test_runs_detail(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid, runId=2)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get( 
+ self.endpoint.format(expid=expid, runId=2) + ) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -234,10 +237,10 @@ class TestRunDetail: class TestQuick: endpoint = "/v3/quick/{expid}" - def test_quick(self, fixture_client: FlaskClient): + def test_quick(self, fixture_fastapi_client: TestClient): expid = "a007" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -248,102 +251,102 @@ class TestQuick: class TestGraph: endpoint = "/v3/graph/{expid}/{graph_type}/{grouped}" - def test_graph_standard_none(self, fixture_client: FlaskClient): + def test_graph_standard_none(self, fixture_fastapi_client: TestClient): expid = "a003" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, graph_type="standard", grouped="none"), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_graph_standard_datemember(self, fixture_client: FlaskClient): + def test_graph_standard_datemember(self, fixture_fastapi_client: TestClient): expid = "a003" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format( expid=expid, graph_type="standard", grouped="date-member" ), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert 
resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_graph_standard_status(self, fixture_client: FlaskClient): + def test_graph_standard_status(self, fixture_fastapi_client: TestClient): expid = "a003" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, graph_type="standard", grouped="status"), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_graph_laplacian_none(self, fixture_client: FlaskClient): + def test_graph_laplacian_none(self, fixture_fastapi_client: TestClient): expid = "a003" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, graph_type="laplacian", grouped="none"), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_graph_standard_none_retro3(self, fixture_client: FlaskClient): + def test_graph_standard_none_retro3(self, fixture_fastapi_client: TestClient): expid = "a3tb" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, graph_type="standard", grouped="none"), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_graph_standard_datemember_retro3(self, fixture_client: 
FlaskClient): + def test_graph_standard_datemember_retro3(self, fixture_fastapi_client: TestClient): expid = "a3tb" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format( expid=expid, graph_type="standard", grouped="date-member" ), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_graph_standard_status_retro3(self, fixture_client: FlaskClient): + def test_graph_standard_status_retro3(self, fixture_fastapi_client: TestClient): expid = "a3tb" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, graph_type="standard", grouped="status"), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False assert resp_obj["total_jobs"] == len(resp_obj["nodes"]) - def test_wrappers(self, fixture_client: FlaskClient): + def test_wrappers(self, fixture_fastapi_client: TestClient): expid = "a6zj" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, graph_type="standard", grouped="none"), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert len(resp_obj["nodes"]) == 10 @@ -360,10 +363,10 @@ class TestGraph: class TestExpCount: endpoint = "/v3/expcount/{expid}" - def test_exp_count(self, fixture_client: FlaskClient): + def test_exp_count(self, fixture_fastapi_client: TestClient): expid = "a003" - response = 
fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -374,10 +377,10 @@ class TestExpCount: assert resp_obj["counters"]["READY"] == 1 assert resp_obj["counters"]["WAITING"] == 7 - def test_retro3(self, fixture_client: FlaskClient): + def test_retro3(self, fixture_fastapi_client: TestClient): expid = "a3tb" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -395,14 +398,14 @@ class TestExpCount: class TestSummary: endpoint = "/v3/summary/{expid}" - def test_summary(self, fixture_client: FlaskClient): + def test_summary(self, fixture_fastapi_client: TestClient): expid = "a007" random_user = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid), - query_string={"loggedUser": random_user}, + params={"loggedUser": random_user}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -412,12 +415,12 @@ class TestSummary: class TestStatistics: endpoint = "/v3/stats/{expid}/{period}/{section}" - def test_period_none(self, fixture_client: FlaskClient): + def test_period_none(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, period=0, section="Any") ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -427,10 +430,10 @@ class TestStatistics: 
class TestCurrentConfig: endpoint = "/v3/cconfig/{expid}" - def test_current_config(self, fixture_client: FlaskClient): + def test_current_config(self, fixture_fastapi_client: TestClient): expid = "a007" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -439,10 +442,10 @@ class TestCurrentConfig: == "4.0.95" ) - def test_retrocomp_v3_conf_files(self, fixture_client: FlaskClient): + def test_retrocomp_v3_conf_files(self, fixture_fastapi_client: TestClient): expid = "a3tb" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -455,10 +458,12 @@ class TestCurrentConfig: class TestPklInfo: endpoint = "/v3/pklinfo/{expid}/{timestamp}" - def test_pkl_info(self, fixture_client: FlaskClient): + def test_pkl_info(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid, timestamp=0)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get( + self.endpoint.format(expid=expid, timestamp=0) + ) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -471,10 +476,12 @@ class TestPklInfo: class TestPklTreeInfo: endpoint = "/v3/pkltreeinfo/{expid}/{timestamp}" - def test_pkl_tree_info(self, fixture_client: FlaskClient): + def test_pkl_tree_info(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid, timestamp=0)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get( + 
self.endpoint.format(expid=expid, timestamp=0) + ) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -487,10 +494,10 @@ class TestPklTreeInfo: class TestExpRunLog: endpoint = "/v3/exprun/{expid}" - def test_exp_run_log(self, fixture_client: FlaskClient): + def test_exp_run_log(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -500,10 +507,10 @@ class TestExpRunLog: class TestIfRunFromLog: endpoint = "/v3/logrun/{expid}" - def test_run_status_from_log(self, fixture_client: FlaskClient): + def test_run_status_from_log(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -515,10 +522,10 @@ class TestIfRunFromLog: class TestQuickIfRun: endpoint = "/v3/ifrun/{expid}" - def test_quick_run_status(self, fixture_client: FlaskClient): + def test_quick_run_status(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -528,10 +535,10 @@ class TestQuickIfRun: class TestJobLogLines: endpoint = "/v3/joblog/{logfile}" - def test_get_logfile_content(self, fixture_client: FlaskClient): + def test_get_logfile_content(self, fixture_fastapi_client: 
TestClient): logfile = "a3tb_19930101_fc01_1_SIM.20211201184808.err" - response = fixture_client.get(self.endpoint.format(logfile=logfile)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(logfile=logfile)) + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -546,13 +553,13 @@ class TestJobLogLines: class TestJobHistory: endpoint = "/v3/history/{expid}/{jobname}" - def test_job_history(self, fixture_client: FlaskClient): + def test_job_history(self, fixture_fastapi_client: TestClient): expid = "a3tb" jobname = "a3tb_19930101_fc01_1_SIM" - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid, jobname=jobname) ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["error_message"] == "" assert resp_obj["error"] is False @@ -564,10 +571,10 @@ class TestJobHistory: class TestSearchExpid: endpoint = "/v3/search/{expid}" - def test_search_by_expid(self, fixture_client: FlaskClient): + def test_search_by_expid(self, fixture_fastapi_client: TestClient): expid = "a3tb" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert isinstance(resp_obj["experiment"], list) assert len(resp_obj["experiment"]) > 0 @@ -576,8 +583,8 @@ class TestSearchExpid: class TestRunningExps: endpoint = "/v3/running/" - def test_search_by_expid(self, fixture_client: FlaskClient): - response = fixture_client.get(self.endpoint) - resp_obj: dict = response.get_json() + def test_search_by_expid(self, fixture_fastapi_client: TestClient): + response = fixture_fastapi_client.get(self.endpoint) + resp_obj: dict = response.json() assert isinstance(resp_obj["experiment"], list) diff --git a/tests/test_endpoints_v4.py b/tests/test_endpoints_v4.py 
index 28ffff26ce53ad5a543d96fbd78fa9089de1bda7..7716caf3bfd6a270d252672d7b0eae6f658e0298 100644 --- a/tests/test_endpoints_v4.py +++ b/tests/test_endpoints_v4.py @@ -1,12 +1,12 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from http import HTTPStatus import random from uuid import uuid4 -from flask.testing import FlaskClient +from fastapi.testclient import TestClient import jwt import pytest from autosubmit_api import config -from autosubmit_api.views.v4 import PAGINATION_LIMIT_DEFAULT, ExperimentJobsViewOptEnum +from autosubmit_api.models.requests import PAGINATION_LIMIT_DEFAULT from tests.utils import custom_return_value @@ -14,65 +14,67 @@ class TestCASV2Login: endpoint = "/v4/auth/cas/v2/login" def test_redirect( - self, fixture_client: FlaskClient, monkeypatch: pytest.MonkeyPatch + self, fixture_fastapi_client: TestClient, monkeypatch: pytest.MonkeyPatch ): random_url = f"https://${str(uuid4())}/" monkeypatch.setattr(config, "CAS_SERVER_URL", random_url) assert random_url == config.CAS_SERVER_URL - response = fixture_client.get(self.endpoint) + response = fixture_fastapi_client.get(self.endpoint, follow_redirects=False) - assert response.status_code == HTTPStatus.FOUND - assert config.CAS_SERVER_URL in response.location + assert response.status_code in [HTTPStatus.FOUND, HTTPStatus.TEMPORARY_REDIRECT] + assert response.has_redirect_location + assert config.CAS_SERVER_URL in response.headers["Location"] def test_invalid_client( - self, fixture_client: FlaskClient, monkeypatch: pytest.MonkeyPatch + self, fixture_fastapi_client: TestClient, monkeypatch: pytest.MonkeyPatch ): monkeypatch.setattr( - "autosubmit_api.views.v4.validate_client", custom_return_value(False) + "autosubmit_api.auth.utils.validate_client", custom_return_value(False) ) - response = fixture_client.get(self.endpoint, query_string={"service": "asd"}) + response = fixture_fastapi_client.get(self.endpoint, params={"service": "asd"}) assert 
response.status_code == HTTPStatus.UNAUTHORIZED class TestJWTVerify: endpoint = "/v4/auth/verify-token" - def test_unauthorized_no_token(self, fixture_client: FlaskClient): - response = fixture_client.get(self.endpoint) - resp_obj: dict = response.get_json() + def test_unauthorized_no_token(self, fixture_fastapi_client: TestClient): + response = fixture_fastapi_client.get(self.endpoint) + resp_obj: dict = response.json() assert response.status_code == HTTPStatus.UNAUTHORIZED assert resp_obj.get("authenticated") is False assert resp_obj.get("user") is None - def test_unauthorized_random_token(self, fixture_client: FlaskClient): + def test_unauthorized_random_token(self, fixture_fastapi_client: TestClient): random_token = str(uuid4()) - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint, headers={"Authorization": random_token} ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert response.status_code == HTTPStatus.UNAUTHORIZED assert resp_obj.get("authenticated") is False assert resp_obj.get("user") is None - def test_authorized(self, fixture_client: FlaskClient): + def test_authorized(self, fixture_fastapi_client: TestClient): random_user = str(uuid4()) payload = { "user_id": random_user, "sub": random_user, "iat": int(datetime.now().timestamp()), "exp": ( - datetime.utcnow() + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) + datetime.now(timezone.utc) + + timedelta(seconds=config.JWT_EXP_DELTA_SECONDS) ), } jwt_token = jwt.encode(payload, config.JWT_SECRET, config.JWT_ALGORITHM) - response = fixture_client.get( - self.endpoint, headers={"Authorization": jwt_token} + response = fixture_fastapi_client.get( + self.endpoint, headers={"Authorization": "Bearer " + jwt_token} ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert response.status_code == HTTPStatus.OK assert resp_obj.get("authenticated") is True @@ -82,23 +84,23 @@ class TestJWTVerify: class TestExperimentList: 
endpoint = "/v4/experiments" - def test_page_size(self, fixture_client: FlaskClient): + def test_page_size(self, fixture_fastapi_client: TestClient): # Default page size - response = fixture_client.get(self.endpoint) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint) + resp_obj: dict = response.json() assert resp_obj["pagination"]["page_size"] == PAGINATION_LIMIT_DEFAULT # Any page size page_size = random.randint(2, 100) - response = fixture_client.get( - self.endpoint, query_string={"page_size": page_size} + response = fixture_fastapi_client.get( + self.endpoint, params={"page_size": page_size} ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert resp_obj["pagination"]["page_size"] == page_size # Unbounded page size - response = fixture_client.get(self.endpoint, query_string={"page_size": -1}) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint, params={"page_size": -1}) + resp_obj: dict = response.json() assert resp_obj["pagination"]["page_size"] is None assert ( resp_obj["pagination"]["page_items"] @@ -111,10 +113,10 @@ class TestExperimentList: class TestExperimentDetail: endpoint = "/v4/experiments/{expid}" - def test_detail(self, fixture_client: FlaskClient): + def test_detail(self, fixture_fastapi_client: TestClient): expid = "a003" - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert resp_obj["id"] == 1 assert resp_obj["name"] == expid @@ -131,14 +133,13 @@ class TestExperimentDetail: class TestExperimentJobs: endpoint = "/v4/experiments/{expid}/jobs" - def test_quick(self, fixture_client: FlaskClient): + def test_quick(self, fixture_fastapi_client: TestClient): expid = "a003" - - response = fixture_client.get( + response = fixture_fastapi_client.get( 
self.endpoint.format(expid=expid), - query_string={"view": ExperimentJobsViewOptEnum.QUICK.value}, + params={"view": "quick"}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert len(resp_obj["jobs"]) == 8 @@ -147,14 +148,13 @@ class TestExperimentJobs: assert isinstance(job["name"], str) and job["name"].startswith(expid) assert isinstance(job["status"], str) - def test_base(self, fixture_client: FlaskClient): + def test_base(self, fixture_fastapi_client: TestClient): expid = "a003" - - response = fixture_client.get( + response = fixture_fastapi_client.get( self.endpoint.format(expid=expid), - query_string={"view": ExperimentJobsViewOptEnum.BASE.value}, + params={"view": "base"}, ) - resp_obj: dict = response.get_json() + resp_obj: dict = response.json() assert len(resp_obj["jobs"]) == 8 @@ -167,11 +167,10 @@ class TestExperimentJobs: class TestExperimentWrappers: endpoint = "/v4/experiments/{expid}/wrappers" - def test_wrappers(self, fixture_client: FlaskClient): + def test_wrappers(self, fixture_fastapi_client: TestClient): expid = "a6zj" - - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert isinstance(resp_obj, dict) assert isinstance(resp_obj["wrappers"], list) @@ -188,11 +187,10 @@ class TestExperimentWrappers: class TestExperimentFSConfig: endpoint = "/v4/experiments/{expid}/filesystem-config" - def test_fs_config(self, fixture_client: FlaskClient): + def test_fs_config(self, fixture_fastapi_client: TestClient): expid = "a6zj" - - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert isinstance(resp_obj, dict) assert isinstance(resp_obj["config"], dict) @@ -204,11 +202,10 @@ class 
TestExperimentFSConfig: assert isinstance(resp_obj["config"]["WRAPPERS"], dict) assert isinstance(resp_obj["config"]["WRAPPERS"]["WRAPPER_V"], dict) - def test_fs_config_v3_retro(self, fixture_client: FlaskClient): + def test_fs_config_v3_retro(self, fixture_fastapi_client: TestClient): expid = "a3tb" - - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert isinstance(resp_obj, dict) assert isinstance(resp_obj["config"], dict) @@ -228,9 +225,9 @@ class TestExperimentRuns: endpoint = "/v4/experiments/{expid}/runs" @pytest.mark.parametrize("expid, num_runs", [("a6zj", 1), ("a3tb", 51)]) - def test_runs(self, expid: str, num_runs: int, fixture_client: FlaskClient): - response = fixture_client.get(self.endpoint.format(expid=expid)) - resp_obj: dict = response.get_json() + def test_runs(self, expid: str, num_runs: int, fixture_fastapi_client: TestClient): + response = fixture_fastapi_client.get(self.endpoint.format(expid=expid)) + resp_obj: dict = response.json() assert isinstance(resp_obj, dict) assert isinstance(resp_obj["runs"], list) @@ -246,12 +243,13 @@ class TestExperimentRuns: class TestExperimentRunConfig: endpoint = "/v4/experiments/{expid}/runs/{run_id}/config" - def test_run_config(self, fixture_client: FlaskClient): + def test_run_config(self, fixture_fastapi_client: TestClient): expid = "a6zj" run_id = 1 - - response = fixture_client.get(self.endpoint.format(expid=expid, run_id=run_id)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get( + self.endpoint.format(expid=expid, run_id=run_id) + ) + resp_obj: dict = response.json() assert isinstance(resp_obj, dict) assert isinstance(resp_obj["config"], dict) @@ -264,11 +262,12 @@ class TestExperimentRunConfig: assert isinstance(resp_obj["config"]["WRAPPERS"]["WRAPPER_V"], dict) @pytest.mark.parametrize("run_id", [51, 
48, 31]) - def test_run_config_v3_retro(self, run_id: int, fixture_client: FlaskClient): + def test_run_config_v3_retro(self, run_id: int, fixture_fastapi_client: TestClient): expid = "a3tb" - - response = fixture_client.get(self.endpoint.format(expid=expid, run_id=run_id)) - resp_obj: dict = response.get_json() + response = fixture_fastapi_client.get( + self.endpoint.format(expid=expid, run_id=run_id) + ) + resp_obj: dict = response.json() assert isinstance(resp_obj, dict) assert isinstance(resp_obj["config"], dict)