diff --git a/app/models.py b/app/models.py index 13c9652..ea9d0ba 100644 --- a/app/models.py +++ b/app/models.py @@ -1,7 +1,6 @@ from app import db from sqlalchemy import Column, INTEGER, String, select import datetime - class Vendor(db.Model): __tablename__ = 'vendor' id = Column('id', INTEGER(), primary_key=True, autoincrement=True) @@ -12,7 +11,7 @@ class Vendor(db.Model): return db.session.execute(select(BudgetCategory).where(BudgetCategory.id==self.bc_id)).scalar() def get_line_items(self): - return db.session.execute(select(LineItem).where(LineItem.vendor_id==self.id)).all() + return db.session.execute(select(LineItem).where(LineItem.vendor_id==self.id)).scalars().all() def __repr__(self) -> str: return f'{self.name}' @@ -23,7 +22,7 @@ class BudgetCategory(db.Model): name = Column('name', String(), nullable=False) def get_vendors(self): - return db.session.execute(select(Vendor).where(Vendor.bc_id==self.id)).all() + return db.session.execute(select(Vendor).where(Vendor.bc_id==self.id)).scalars().all() def get_line_items(self): line_items = [] @@ -32,6 +31,17 @@ class BudgetCategory(db.Model): line_items += vendor.get_line_items() return line_items + def get_total_month_cost(self, month:int, year:int): + from .utils import get_month_timestamps + first_ts, last_ts = get_month_timestamps(month, year) + lis = self.get_line_items() + requested_month_lis = 0 + for li in lis: + if first_ts < li.date < last_ts: + requested_month_lis += li.amount + return requested_month_lis + + class LineItem(db.Model): __tablename__ = 'line_item' id = Column('id', INTEGER(), primary_key=True, autoincrement=True)