code.lukegrehan.com envbudget / master backend / src / app.py
master

Tree @master (Download .tar.gz)

app.py @master · raw · history · blame

import os
import sqlite3
from flask import Flask, g, jsonify, request, abort
from flask_cors import CORS, cross_origin
import threading
from dataclasses import dataclass

# SQLite database file; an unset *or empty* DB_PATH falls back to "test.db".
DB_PATH = os.getenv('DB_PATH') or "test.db"

# The Flask application, with flask_cors attached to it.
app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = "Content-Type"

app.logger.info(f"DB_PATH={DB_PATH}")
def mkDict(cursor, row):
    """sqlite3 row factory: build a dict mapping each result column name to its value."""
    # cursor.description holds one tuple per column; element 0 is the column name.
    columnNames = (column[0] for column in cursor.description)
    return dict(zip(columnNames, row))

def get_db():
    """Return the SQLite connection for the current app context, opening it on first use.

    The connection is cached on ``flask.g`` under ``_db``. It uses ``mkDict`` as its
    row factory, so queries return rows as dicts keyed by column name.
    """
    cached = getattr(g, '_db', None)
    if cached is not None:
        return cached
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = mkDict
    g._db = conn
    return conn

@app.teardown_appcontext
def close_db(exception):
    """App-context teardown hook: close the cached connection if ``get_db`` opened one."""
    if (conn := getattr(g, '_db', None)) is not None:
        conn.close()

def query(query, *args, one=False, rowid=False):
    """Execute one SQL statement on the current context's connection.

    The statement runs inside ``with db:``. The sqlite3 connection context manager
    commits on success and rolls back if an exception escapes.

    Args:
        query: SQL text with ``?`` placeholders.
        *args: Values bound to the placeholders, in order.
        one: If true, return only the first row (or ``None`` when there are no rows).
        rowid: If true, return ``cursor.lastrowid``, which is useful after an INSERT.

    Returns:
        The fetched rows (as built by the connection's row factory), a single row or
        ``None`` when ``one`` is set, or the last inserted rowid when ``rowid`` is set.
    """
    db = get_db()
    with db:
        # Execute *before* entering the try. If execute() raises, there is no cursor to
        # close, and the original sqlite3 error must propagate instead of being masked
        # by an UnboundLocalError from the finally clause.
        cur = db.execute(query, args)
        try:
            rv = cur.fetchall()
            if rowid:
                return cur.lastrowid
        finally:
            cur.close()
    return (rv[0] if rv else None) if one else rv

#----------------------------------------#

def updateDb(envs):
    """Apply a batch of envelope edits on a background thread.

    ``envs`` is iterated as a sequence of dicts (the PATCH /envs body). Each entry
    needs a ``rowid``; entries without one are skipped. Any of ``name``, ``total``,
    ``val`` and ``ord`` that is present and not None overwrites the matching
    Envelopes column for that row.

    Returns immediately. The work runs on a new thread inside its own app context,
    and therefore on its own connection from ``get_db``.
    """
    # Payload keys that may be updated; each key is also the column it sets. Only these
    # fixed names are interpolated into the SQL below, never anything from the request.
    columns = ("name", "total", "val", "ord")

    def _run():
        with app.app_context():
            app.logger.info("beginning update...")
            try:
                db = get_db()
                # A single transaction for the whole batch: the connection context
                # manager commits once at the end, or rolls everything back on error.
                # (Calling query() here would commit after every UPDATE, because
                # query() opens its own ``with db:``.)
                with db:
                    for env in envs:
                        if (rowId := env.get("rowid")) is None:
                            continue
                        app.logger.info(f"updating row {rowId}...")
                        for column in columns:
                            if (value := env.get(column)) is not None:
                                db.execute(
                                    f"UPDATE Envelopes SET {column} = ? WHERE rowId = ?;",
                                    (value, rowId),
                                )
            except Exception:
                # This is a worker thread and the client has already been answered,
                # so log the failure (with traceback) instead of losing it.
                app.logger.exception("update failed")
                return
            app.logger.info("update finished")

    threading.Thread(target=_run).start()

#----------------------------------------#

@app.route("/envs", methods=["GET"])
@cross_origin()
def get_envs():
    """Return every envelope as a JSON list, ordered by its ``ord`` column."""
    app.logger.debug("get all")
    envelopes = query("SELECT rowid, name, total, val, ord FROM Envelopes ORDER BY ord;")
    return jsonify(envelopes)

@app.route("/envs", methods=["POST"])
@cross_origin()
def new_env():
    """Create an envelope from the JSON body and return its new rowid (status 202).

    The body must be a JSON object with ``name`` and ``val``; ``total`` is optional.
    The new envelope is appended after all existing ones: its ``ord`` is the current
    maximum ``ord`` plus one, or 1 when there is no previous value.

    Aborts with 400 if the body is not a JSON object or a required field is missing.
    """
    data = request.get_json()
    # A JSON null, list or scalar would otherwise crash on .get() with a 500.
    if not isinstance(data, dict):
        abort(400)
    name = data.get("name")
    total = data.get("total")
    val = data.get("val")
    if name is None or val is None:
        abort(400)

    app.logger.info(f"adding env `{name}`")
    # Compute the next ord inside the INSERT itself. A single statement is atomic, so
    # two concurrent creates cannot pick the same ord. COALESCE handles both an empty
    # table and a NULL ord, which previously raised TypeError (None + 1).
    rowId = query(
        "INSERT INTO Envelopes (name, val, total, ord) "
        "SELECT ?, ?, ?, COALESCE(MAX(ord), 0) + 1 FROM Envelopes;",
        name, val, total,
        rowid=True,
    )
    app.logger.info(f"done adding env `{name}`")
    # NOTE(review): the insert is synchronous, so 201 Created would be more accurate;
    # 202 is kept in case the frontend checks for it.
    return str(rowId), 202

@app.route("/envs", methods=["PATCH"])
@cross_origin()
def update_envs():
    """Queue a batch of envelope edits; respond 202 once queued, 400 if malformed.

    The body must be a JSON list of objects, each describing the edits for one
    envelope. The edits themselves are applied asynchronously by ``updateDb``.
    """
    data = request.get_json()
    # Validate here: updateDb does its work on a background thread, so a malformed
    # payload would otherwise only fail there, after the client already got 202.
    if not isinstance(data, list) or not all(isinstance(env, dict) for env in data):
        abort(400)
    updateDb(data)
    return "", 202

@app.route("/envs/<rowid>", methods=["DELETE"])
@cross_origin()
def delete_env(rowid):
    """Delete envelope ``rowid``: 204 on success, 404 if no such envelope exists."""
    # EXISTS yields 0 or 1; the alias gives the result dict a predictable key.
    found = query(
        "SELECT EXISTS(SELECT 1 FROM Envelopes WHERE rowid = ?) AS found",
        rowid,
        one=True,
    )["found"]
    if not found:
        # Logger.warn is a deprecated alias; use warning().
        app.logger.warning(f"env {rowid} doesn't exist")
        abort(404)
    app.logger.warning(f"deleting env {rowid}")
    query("DELETE FROM Envelopes WHERE rowid = ?", rowid)
    return "", 204

@app.route("/")
@app.route("/<path:path>")
def staticFallback(path='index.html'):
    """Serve ``path`` from the static folder, or an empty 404 if it isn't a file there.

    ``/`` serves ``index.html``. ``path`` comes straight from the URL, so it is
    confined to the static folder before any filesystem check is made.
    """
    app.logger.debug("find path")
    staticRoot = os.path.abspath(app.static_folder)
    fullPath = os.path.abspath(os.path.join(staticRoot, path))
    app.logger.debug(f"fullPath: {fullPath}")
    # Without this check, "../../etc/passwd" or an absolute path escapes the static
    # folder. The different 404 bodies would then reveal whether arbitrary files
    # exist on the server.
    try:
        insideRoot = os.path.commonpath([staticRoot, fullPath]) == staticRoot
    except ValueError:  # e.g. paths on different drives (Windows)
        insideRoot = False
    # isfile rather than exists: directories cannot be served as static files.
    if insideRoot and os.path.isfile(fullPath):
        app.logger.debug("path exists")
        return app.send_static_file(path)
    app.logger.debug("path does not exist")
    return "", 404

# First start only: if the database file is missing, create it (sqlite3.connect does
# this) and load schema.sql into it. An existing file is assumed to be initialised.
if not os.path.exists(DB_PATH):
    with app.app_context():
        db = get_db()
        with app.open_resource('schema.sql', mode='r') as f:
            db.executescript(f.read())
        db.commit()