Skip to content

Commit fbaaac8

Browse files
committed
Implement account.Update w/ follow logic
1 parent c1189d0 commit fbaaac8

6 files changed

Lines changed: 139 additions & 7 deletions

File tree

backend/command/follow.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,33 @@ func (cmd *Command) FollowBlog(accountID uuid.UUID, blogID uuid.UUID) error {
1313
return err
1414
}
1515

16-
err = account.FollowBlog(blogID)
16+
blog, err := tx.Blog().Read(blogID)
17+
if err != nil {
18+
return err
19+
}
20+
21+
err = account.FollowBlog(blog)
22+
if err != nil {
23+
return err
24+
}
25+
26+
return tx.Account().Update(account)
27+
})
28+
}
29+
30+
func (cmd *Command) UnfollowBlog(accountID uuid.UUID, blogID uuid.UUID) error {
31+
return cmd.repo.WithTransaction(func(tx *repository.Repository) error {
32+
account, err := tx.Account().Read(accountID)
33+
if err != nil {
34+
return err
35+
}
36+
37+
blog, err := tx.Blog().Read(blogID)
38+
if err != nil {
39+
return err
40+
}
41+
42+
err = account.UnfollowBlog(blog)
1743
if err != nil {
1844
return err
1945
}

backend/model/account.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,22 @@ func (a *Account) FollowedBlogIDs() []uuid.UUID {
6868
return a.followedBlogIDs
6969
}
7070

71-
func (a *Account) FollowBlog(blogID uuid.UUID) error {
72-
if slices.Contains(a.followedBlogIDs, blogID) {
71+
func (a *Account) FollowBlog(blog *Blog) error {
72+
if slices.Contains(a.followedBlogIDs, blog.ID()) {
7373
return nil
7474
}
7575

76-
a.followedBlogIDs = append(a.followedBlogIDs, blogID)
76+
a.followedBlogIDs = append(a.followedBlogIDs, blog.ID())
7777
return nil
7878
}
7979

80-
func (a *Account) UnfollowBlog(blogID uuid.UUID) error {
81-
if !slices.Contains(a.followedBlogIDs, blogID) {
80+
func (a *Account) UnfollowBlog(blog *Blog) error {
81+
index := slices.Index(a.followedBlogIDs, blog.ID())
82+
if index == -1 {
8283
return nil
8384
}
8485

85-
a.followedBlogIDs = slices.Delete(a.followedBlogIDs, slices.Index(a.followedBlogIDs, blogID), slices.Index(a.followedBlogIDs, blogID)+1)
86+
a.followedBlogIDs = slices.Delete(a.followedBlogIDs, index, index+1)
8687
return nil
8788
}
8889

backend/repository/account.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/sha256"
66
"encoding/hex"
7+
"slices"
78
"time"
89

910
"github.com/google/uuid"
@@ -229,8 +230,59 @@ func (r *AccountRepository) List(limit, offset int) ([]*model.Account, error) {
229230

230231
func (r *AccountRepository) Update(account *model.Account) error {
231232
// List blogs currently being followed in the database.
233+
stmt := `
234+
SELECT
235+
account_blog.blog_id
236+
FROM account_blog
237+
WHERE account_blog.account_id = $1`
238+
239+
rows, err := QueryWithTimeout(r.conn, stmt, account.ID())
240+
if err != nil {
241+
return err
242+
}
243+
244+
followedBlogIDs, err := pgx.CollectRows(rows, pgx.RowTo[uuid.UUID])
245+
if err != nil {
246+
return postgres.CheckListError(err)
247+
}
248+
232249
// Set diff to find which blogs to add or remove.
250+
var blogsToFollow []uuid.UUID
251+
for _, blogID := range account.FollowedBlogIDs() {
252+
if !slices.Contains(followedBlogIDs, blogID) {
253+
blogsToFollow = append(blogsToFollow, blogID)
254+
}
255+
}
256+
257+
var blogsToUnfollow []uuid.UUID
258+
for _, blogID := range followedBlogIDs {
259+
if !slices.Contains(account.FollowedBlogIDs(), blogID) {
260+
blogsToUnfollow = append(blogsToUnfollow, blogID)
261+
}
262+
}
263+
233264
// Add and remove blogs as necessary.
265+
stmtFollow := `
266+
INSERT INTO account_blog
267+
(account_id, blog_id)
268+
VALUES ($1, $2)`
269+
for _, blogID := range blogsToFollow {
270+
err = ExecWithTimeout(r.conn, stmtFollow, account.ID(), blogID)
271+
if err != nil {
272+
return postgres.CheckCreateError(err)
273+
}
274+
}
275+
276+
stmtUnfollow := `
277+
DELETE FROM account_blog
278+
WHERE account_id = $1 AND blog_id = $2`
279+
for _, blogID := range blogsToUnfollow {
280+
err = ExecWithTimeout(r.conn, stmtUnfollow, account.ID(), blogID)
281+
if err != nil {
282+
return postgres.CheckDeleteError(err)
283+
}
284+
}
285+
234286
return nil
235287
}
236288

backend/repository/account_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,33 @@ func TestAccountList(t *testing.T) {
9090
test.AssertEqual(t, len(accounts), limit)
9191
}
9292

93+
func TestAccountUpdate(t *testing.T) {
94+
t.Parallel()
95+
96+
repo, closer := test.NewRepository(t)
97+
defer closer()
98+
99+
account := test.CreateAccount(t, repo)
100+
blog := test.CreateBlog(t, repo)
101+
102+
account.FollowBlog(blog)
103+
err := repo.Account().Update(account)
104+
test.AssertNilError(t, err)
105+
106+
updatedAccount, err := repo.Account().Read(account.ID())
107+
test.AssertNilError(t, err)
108+
109+
test.AssertSliceContains(t, updatedAccount.FollowedBlogIDs(), blog.ID())
110+
111+
account.UnfollowBlog(blog)
112+
err = repo.Account().Update(account)
113+
test.AssertNilError(t, err)
114+
115+
updatedAccount, err = repo.Account().Read(account.ID())
116+
test.AssertNilError(t, err)
117+
test.AssertSliceDoesNotContain(t, updatedAccount.FollowedBlogIDs(), blog.ID())
118+
}
119+
93120
func TestAccountDelete(t *testing.T) {
94121
t.Parallel()
95122

backend/repository/repository.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package repository
33
import (
44
"context"
55

6+
"github.com/jackc/pgx/v5"
7+
68
"github.com/theandrew168/bloggulus/backend/postgres"
79
)
810

@@ -100,3 +102,19 @@ func (r *Repository) WithTransaction(operation func(repo *Repository) error) err
100102

101103
return nil
102104
}
105+
106+
func QueryWithTimeout(conn postgres.Conn, stmt string, args ...any) (pgx.Rows, error) {
107+
ctx, cancel := context.WithTimeout(context.Background(), postgres.Timeout)
108+
defer cancel()
109+
110+
rows, err := conn.Query(ctx, stmt, args...)
111+
return rows, err
112+
}
113+
114+
func ExecWithTimeout(conn postgres.Conn, stmt string, args ...any) error {
115+
ctx, cancel := context.WithTimeout(context.Background(), postgres.Timeout)
116+
defer cancel()
117+
118+
_, err := conn.Exec(ctx, stmt, args...)
119+
return err
120+
}

backend/test/assert.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ func AssertSliceContains[T comparable](t *testing.T, got []T, want T) {
4848
}
4949
}
5050

51+
func AssertSliceDoesNotContain[T comparable](t *testing.T, got []T, want T) {
52+
t.Helper()
53+
54+
if slices.Contains(got, want) {
55+
t.Fatalf("got %v; should not contain: %v", got, want)
56+
}
57+
}
58+
5159
func AssertNilError(t *testing.T, got error) {
5260
t.Helper()
5361

0 commit comments

Comments
 (0)