|
4 | 4 | "context" |
5 | 5 | "crypto/sha256" |
6 | 6 | "encoding/hex" |
| 7 | + "slices" |
7 | 8 | "time" |
8 | 9 |
|
9 | 10 | "github.com/google/uuid" |
@@ -229,8 +230,59 @@ func (r *AccountRepository) List(limit, offset int) ([]*model.Account, error) { |
229 | 230 |
|
230 | 231 | func (r *AccountRepository) Update(account *model.Account) error { |
231 | 232 | // List blogs currently being followed in the database. |
| 233 | + stmt := ` |
| 234 | + SELECT |
| 235 | + account_blog.blog_id |
| 236 | + FROM account_blog |
| 237 | + WHERE account_blog.account_id = $1` |
| 238 | + |
| 239 | + rows, err := QueryWithTimeout(r.conn, stmt, account.ID()) |
| 240 | + if err != nil { |
| 241 | + return err |
| 242 | + } |
| 243 | + |
| 244 | + followedBlogIDs, err := pgx.CollectRows(rows, pgx.RowTo[uuid.UUID]) |
| 245 | + if err != nil { |
| 246 | + return postgres.CheckListError(err) |
| 247 | + } |
| 248 | + |
232 | 249 | // Set diff to find which blogs to add or remove. |
| 250 | + var blogsToFollow []uuid.UUID |
| 251 | + for _, blogID := range account.FollowedBlogIDs() { |
| 252 | + if !slices.Contains(followedBlogIDs, blogID) { |
| 253 | + blogsToFollow = append(blogsToFollow, blogID) |
| 254 | + } |
| 255 | + } |
| 256 | + |
| 257 | + var blogsToUnfollow []uuid.UUID |
| 258 | + for _, blogID := range followedBlogIDs { |
| 259 | + if !slices.Contains(account.FollowedBlogIDs(), blogID) { |
| 260 | + blogsToUnfollow = append(blogsToUnfollow, blogID) |
| 261 | + } |
| 262 | + } |
| 263 | + |
233 | 264 | // Add and remove blogs as necessary. |
| 265 | + stmtFollow := ` |
| 266 | + INSERT INTO account_blog |
| 267 | + (account_id, blog_id) |
| 268 | + VALUES ($1, $2)` |
| 269 | + for _, blogID := range blogsToFollow { |
| 270 | + err = ExecWithTimeout(r.conn, stmtFollow, account.ID(), blogID) |
| 271 | + if err != nil { |
| 272 | + return postgres.CheckCreateError(err) |
| 273 | + } |
| 274 | + } |
| 275 | + |
| 276 | + stmtUnfollow := ` |
| 277 | + DELETE FROM account_blog |
| 278 | + WHERE account_id = $1 AND blog_id = $2` |
| 279 | + for _, blogID := range blogsToUnfollow { |
| 280 | + err = ExecWithTimeout(r.conn, stmtUnfollow, account.ID(), blogID) |
| 281 | + if err != nil { |
| 282 | + return postgres.CheckDeleteError(err) |
| 283 | + } |
| 284 | + } |
| 285 | + |
234 | 286 | return nil |
235 | 287 | } |
236 | 288 |
|
|
0 commit comments