Skip to content

Commit

Permalink
Update Context Requests (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
aidantrabs authored Dec 12, 2024
2 parents 868d351 + 3618422 commit 26abae3
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 138 deletions.
23 changes: 16 additions & 7 deletions backend/internal/server/company.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package server

import (
"context"
"net/http"

"KonferCA/SPUR/db"
Expand All @@ -14,6 +13,8 @@ import (
)

func (s *Server) handleCreateCompany(c echo.Context) error {
ctx := c.Request().Context()

var req *CreateCompanyRequest
req, ok := c.Get(mw.REQUEST_BODY_KEY).(*CreateCompanyRequest)
if !ok {
Expand All @@ -32,7 +33,7 @@ func (s *Server) handleCreateCompany(c echo.Context) error {
Description: req.Description,
}

company, err := queries.CreateCompany(context.Background(), params)
company, err := queries.CreateCompany(ctx, params)
if err != nil {
return handleDBError(err, "create", "company")
}
Expand All @@ -41,12 +42,14 @@ func (s *Server) handleCreateCompany(c echo.Context) error {
}

func (s *Server) handleGetUserCompany(c echo.Context) error {
ctx := c.Request().Context()

claims, ok := c.Get(middleware.JWT_CLAIMS).(*jwt.JWTClaims)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to type cast jwt claims")
}

company, err := s.queries.GetCompanyByUser(c.Request().Context(), claims.UserID)
company, err := s.queries.GetCompanyByUser(ctx, claims.UserID)
if err != nil {
if isNoRowsError(err) {
return echo.NewHTTPError(http.StatusNotFound, "No company found")
Expand All @@ -59,13 +62,15 @@ func (s *Server) handleGetUserCompany(c echo.Context) error {
}

func (s *Server) handleGetCompany(c echo.Context) error {
ctx := c.Request().Context()

companyID, err := validateUUID(c.Param("id"), "company")
if err != nil {
return err
}

queries := db.New(s.DBPool)
company, err := queries.GetCompanyByID(context.Background(), companyID)
company, err := queries.GetCompanyByID(ctx, companyID)
if err != nil {
return handleDBError(err, "fetch", "company")
}
Expand All @@ -74,8 +79,10 @@ func (s *Server) handleGetCompany(c echo.Context) error {
}

func (s *Server) handleListCompanies(c echo.Context) error {
ctx := c.Request().Context()

queries := db.New(s.DBPool)
companies, err := queries.ListCompanies(context.Background())
companies, err := queries.ListCompanies(ctx)
if err != nil {
return handleDBError(err, "fetch", "companies")
}
Expand All @@ -84,18 +91,20 @@ func (s *Server) handleListCompanies(c echo.Context) error {
}

func (s *Server) handleDeleteCompany(c echo.Context) error {
ctx := c.Request().Context()

companyID, err := validateUUID(c.Param("id"), "company")
if err != nil {
return err
}

queries := db.New(s.DBPool)
_, err = queries.GetCompanyByID(context.Background(), companyID)
_, err = queries.GetCompanyByID(ctx, companyID)
if err != nil {
return handleDBError(err, "verify", "company")
}

err = queries.DeleteCompany(context.Background(), companyID)
err = queries.DeleteCompany(ctx, companyID)
if err != nil {
return handleDBError(err, "delete", "company")
}
Expand Down
34 changes: 20 additions & 14 deletions backend/internal/server/company_documents.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package server

import (
"context"
"net/http"

"KonferCA/SPUR/db"
mw "KonferCA/SPUR/internal/middleware"

"github.com/labstack/echo/v4"
)

func (s *Server) handleCreateCompanyDocument(c echo.Context) error {
ctx := c.Request().Context()

var req *CreateCompanyDocumentRequest
req, ok := c.Get(mw.REQUEST_BODY_KEY).(*CreateCompanyDocumentRequest)
if !ok {
Expand All @@ -22,8 +24,7 @@ func (s *Server) handleCreateCompanyDocument(c echo.Context) error {
}

queries := db.New(s.DBPool)

_, err = queries.GetCompanyByID(context.Background(), companyID)
_, err = queries.GetCompanyByID(ctx, companyID)
if err != nil {
return handleDBError(err, "verify", "company")
}
Expand All @@ -34,7 +35,7 @@ func (s *Server) handleCreateCompanyDocument(c echo.Context) error {
FileUrl: req.FileURL,
}

document, err := queries.CreateCompanyDocument(context.Background(), params)
document, err := queries.CreateCompanyDocument(ctx, params)
if err != nil {
return handleDBError(err, "create", "company document")
}
Expand All @@ -43,13 +44,15 @@ func (s *Server) handleCreateCompanyDocument(c echo.Context) error {
}

func (s *Server) handleGetCompanyDocument(c echo.Context) error {
ctx := c.Request().Context()

documentID, err := validateUUID(c.Param("id"), "document")
if err != nil {
return err
}

queries := db.New(s.DBPool)
document, err := queries.GetCompanyDocumentByID(context.Background(), documentID)
document, err := queries.GetCompanyDocumentByID(ctx, documentID)
if err != nil {
return handleDBError(err, "fetch", "company document")
}
Expand All @@ -58,28 +61,29 @@ func (s *Server) handleGetCompanyDocument(c echo.Context) error {
}

func (s *Server) handleListCompanyDocuments(c echo.Context) error {
ctx := c.Request().Context()

companyID, err := validateUUID(c.Param("id"), "company")
if err != nil {
return err
}

queries := db.New(s.DBPool)

documentType := c.QueryParam("document_type")
if documentType != "" {
params := db.ListDocumentsByTypeParams{
CompanyID: companyID,
DocumentType: documentType,
}

documents, err := queries.ListDocumentsByType(context.Background(), params)
documents, err := queries.ListDocumentsByType(ctx, params)
if err != nil {
return handleDBError(err, "fetch", "company documents")
}
return c.JSON(http.StatusOK, documents)
}

documents, err := queries.ListCompanyDocuments(context.Background(), companyID)
documents, err := queries.ListCompanyDocuments(ctx, companyID)
if err != nil {
return handleDBError(err, "fetch", "company documents")
}
Expand All @@ -88,6 +92,8 @@ func (s *Server) handleListCompanyDocuments(c echo.Context) error {
}

func (s *Server) handleUpdateCompanyDocument(c echo.Context) error {
ctx := c.Request().Context()

documentID, err := validateUUID(c.Param("id"), "document")
if err != nil {
return err
Expand All @@ -100,8 +106,7 @@ func (s *Server) handleUpdateCompanyDocument(c echo.Context) error {
}

queries := db.New(s.DBPool)

_, err = queries.GetCompanyDocumentByID(context.Background(), documentID)
_, err = queries.GetCompanyDocumentByID(ctx, documentID)
if err != nil {
return handleDBError(err, "verify", "company document")
}
Expand All @@ -112,7 +117,7 @@ func (s *Server) handleUpdateCompanyDocument(c echo.Context) error {
FileUrl: req.FileURL,
}

document, err := queries.UpdateCompanyDocument(context.Background(), params)
document, err := queries.UpdateCompanyDocument(ctx, params)
if err != nil {
return handleDBError(err, "update", "company document")
}
Expand All @@ -121,19 +126,20 @@ func (s *Server) handleUpdateCompanyDocument(c echo.Context) error {
}

func (s *Server) handleDeleteCompanyDocument(c echo.Context) error {
ctx := c.Request().Context()

documentID, err := validateUUID(c.Param("id"), "document")
if err != nil {
return err
}

queries := db.New(s.DBPool)

_, err = queries.GetCompanyDocumentByID(context.Background(), documentID)
_, err = queries.GetCompanyDocumentByID(ctx, documentID)
if err != nil {
return handleDBError(err, "verify", "company document")
}

err = queries.DeleteCompanyDocument(context.Background(), documentID)
err = queries.DeleteCompanyDocument(ctx, documentID)
if err != nil {
return handleDBError(err, "delete", "company document")
}
Expand Down
33 changes: 19 additions & 14 deletions backend/internal/server/company_financials.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package server

import (
"context"
"net/http"
"strconv"

"KonferCA/SPUR/db"
mw "KonferCA/SPUR/internal/middleware"

"github.com/labstack/echo/v4"
)

func (s *Server) handleCreateCompanyFinancials(c echo.Context) error {
ctx := c.Request().Context()

var req *CreateCompanyFinancialsRequest
req, ok := c.Get(mw.REQUEST_BODY_KEY).(*CreateCompanyFinancialsRequest)
if !ok {
Expand All @@ -23,8 +25,7 @@ func (s *Server) handleCreateCompanyFinancials(c echo.Context) error {
}

queries := db.New(s.DBPool)

_, err = queries.GetCompanyByID(context.Background(), companyID)
_, err = queries.GetCompanyByID(ctx, companyID)
if err != nil {
return handleDBError(err, "verify", "company")
}
Expand All @@ -41,7 +42,7 @@ func (s *Server) handleCreateCompanyFinancials(c echo.Context) error {
GrantsReceived: numericFromFloat(req.GrantsReceived),
}

financials, err := queries.CreateCompanyFinancials(context.Background(), params)
financials, err := queries.CreateCompanyFinancials(ctx, params)
if err != nil {
return handleDBError(err, "create", "company financials")
}
Expand All @@ -50,13 +51,14 @@ func (s *Server) handleCreateCompanyFinancials(c echo.Context) error {
}

func (s *Server) handleGetCompanyFinancials(c echo.Context) error {
ctx := c.Request().Context()

companyID, err := validateUUID(c.Param("id"), "company")
if err != nil {
return err
}

queries := db.New(s.DBPool)

year := c.QueryParam("year")
if year != "" {
yearInt, err := strconv.ParseInt(year, 10, 32)
Expand All @@ -69,15 +71,15 @@ func (s *Server) handleGetCompanyFinancials(c echo.Context) error {
FinancialYear: int32(yearInt),
}

financials, err := queries.GetCompanyFinancialsByYear(context.Background(), params)
financials, err := queries.GetCompanyFinancialsByYear(ctx, params)
if err != nil {
return handleDBError(err, "fetch", "company financials")
}

return c.JSON(http.StatusOK, financials)
}

financials, err := queries.ListCompanyFinancials(context.Background(), companyID)
financials, err := queries.ListCompanyFinancials(ctx, companyID)
if err != nil {
return handleDBError(err, "fetch", "company financials")
}
Expand All @@ -86,6 +88,8 @@ func (s *Server) handleGetCompanyFinancials(c echo.Context) error {
}

func (s *Server) handleUpdateCompanyFinancials(c echo.Context) error {
ctx := c.Request().Context()

var req *CreateCompanyFinancialsRequest
req, ok := c.Get(mw.REQUEST_BODY_KEY).(*CreateCompanyFinancialsRequest)
if !ok {
Expand All @@ -108,8 +112,7 @@ func (s *Server) handleUpdateCompanyFinancials(c echo.Context) error {
}

queries := db.New(s.DBPool)

_, err = queries.GetCompanyByID(context.Background(), companyID)
_, err = queries.GetCompanyByID(ctx, companyID)
if err != nil {
return handleDBError(err, "verify", "company")
}
Expand All @@ -126,7 +129,7 @@ func (s *Server) handleUpdateCompanyFinancials(c echo.Context) error {
GrantsReceived: numericFromFloat(req.GrantsReceived),
}

financials, err := queries.UpdateCompanyFinancials(context.Background(), params)
financials, err := queries.UpdateCompanyFinancials(ctx, params)
if err != nil {
return handleDBError(err, "update", "company financials")
}
Expand All @@ -135,6 +138,8 @@ func (s *Server) handleUpdateCompanyFinancials(c echo.Context) error {
}

func (s *Server) handleDeleteCompanyFinancials(c echo.Context) error {
ctx := c.Request().Context()

companyID, err := validateUUID(c.Param("id"), "company")
if err != nil {
return err
Expand All @@ -151,13 +156,12 @@ func (s *Server) handleDeleteCompanyFinancials(c echo.Context) error {
}

queries := db.New(s.DBPool)

params := db.DeleteCompanyFinancialsParams{
CompanyID: companyID,
FinancialYear: int32(yearInt),
}

err = queries.DeleteCompanyFinancials(context.Background(), params)
err = queries.DeleteCompanyFinancials(ctx, params)
if err != nil {
return handleDBError(err, "delete", "company financials")
}
Expand All @@ -166,14 +170,15 @@ func (s *Server) handleDeleteCompanyFinancials(c echo.Context) error {
}

func (s *Server) handleGetLatestCompanyFinancials(c echo.Context) error {
ctx := c.Request().Context()

companyID, err := validateUUID(c.Param("id"), "company")
if err != nil {
return err
}

queries := db.New(s.DBPool)

financials, err := queries.GetLatestCompanyFinancials(context.Background(), companyID)
financials, err := queries.GetLatestCompanyFinancials(ctx, companyID)
if err != nil {
return handleDBError(err, "fetch", "latest company financials")
}
Expand Down
Loading

0 comments on commit 26abae3

Please sign in to comment.