分享
  1. 首页
  2. 文章

Go+typescript+GraphQL+react构建简书网站(三) 编写Model

云燕 · · 1103 次点击 · · 开始浏览
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

补遗:数据库增加Tag表

新建tag表:

CREATE TABLE "public"."tag" (
 "id" int8 NOT NULL,
 "name" varchar(255) NOT NULL,
 "created_at" timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
 "updated_at" timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
 "deleted_at" timestamp(6) NOT NULL,
 PRIMARY KEY ("id")
)
;
COMMENT ON COLUMN "public"."tag"."id" IS 'ID';
COMMENT ON COLUMN "public"."tag"."name" IS '标签名';
COMMENT ON COLUMN "public"."tag"."created_at" IS '创建时间';
COMMENT ON COLUMN "public"."tag"."updated_at" IS '更新时间';
COMMENT ON COLUMN "public"."tag"."deleted_at" IS '删除时间';

这里不得不说一下,由于是一边写代码一边写文章(文章的作用只是用来给自己厘清思路),所以文章中的代码内容很可能下一次就变了,毕竟文章中的代码,只是我初步写时的思路,肯定存在错漏之处,后续会慢慢完善。如要看最新的代码,还请移步:https://github.com/unrotten/h...

编写CURD基础方法

依然先看结果,修改db.go文件:

package model
import (
 "context"
 "database/sql"
 "database/sql/driver"
 "fmt"
 "github.com/jmoiron/sqlx"
 _ "github.com/lib/pq"
 "github.com/rs/zerolog"
 "github.com/sony/sonyflake"
 "github.com/spf13/viper"
 "github.com/unrotten/builder"
 "github.com/unrotten/sqlex"
 "log"
 "os"
 "reflect"
 "time"
)
var (
 DB *sqlx.DB
 psql sqlex.StatementBuilderType
 idfetcher *sonyflake.Sonyflake
)
const defaultSkip int = 2
type cv map[string]interface{}
type where []sqlex.Sqlex
type result struct {
 b builder.Builder
 success bool
}
// 初始化数据库连接
func init() {
 viper.AddConfigPath("../config") // 测试使用
 viper.ReadInConfig()
 // 获取数据库配置信息
 user := viper.Get("storage.user")
 password := viper.Get("storage.password")
 host := viper.Get("storage.host")
 port := viper.Get("storage.port")
 dbname := viper.Get("storage.dbname")
 // 连接数据库
 psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
 host, port, user, password, dbname)
 DB = sqlx.MustOpen("postgres", psqlInfo)
 if err := DB.Ping(); err != nil {
 log.Fatalf("连接数据库失败:%s", err)
 }
 // 初始化sql构建器,指定format形式
 psql = sqlex.StatementBuilder.PlaceholderFormat(sqlex.Dollar)
 sqlex.SetLogger(os.Stdout)
 // 初始化sonyflake
 st := sonyflake.Settings{
 StartTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.Local),
 }
 idfetcher = sonyflake.NewSonyflake(st)
}
func get(query *sql.Rows, columnTypes []*sql.ColumnType, logger zerolog.Logger) result {
 dest := make([]interface{}, len(columnTypes))
 for index, col := range columnTypes {
 switch col.ScanType().String() {
 case "string", "interface {}":
 dest[index] = &sql.NullString{}
 case "bool":
 dest[index] = &sql.NullBool{}
 case "float64":
 dest[index] = &sql.NullFloat64{}
 case "int32":
 dest[index] = &sql.NullInt32{}
 case "int64":
 dest[index] = &sql.NullInt64{}
 case "time.Time":
 dest[index] = &sql.NullTime{}
 default:
 dest[index] = reflect.New(col.ScanType()).Interface()
 }
 }
 err := query.Scan(dest...)
 if err != nil {
 logger.Error().Caller(2).Err(err).Send()
 return result{success: false}
 }
 build := builder.EmptyBuilder
 for index, col := range columnTypes {
 switch val := dest[index].(type) {
 case driver.Valuer:
 var value interface{}
 switch col.ScanType().String() {
 case "string", "interface {}":
 value = dest[index].(*sql.NullString).String
 case "bool":
 value = dest[index].(*sql.NullBool).Bool
 case "float64":
 value = dest[index].(*sql.NullFloat64).Float64
 case "int32":
 value = dest[index].(*sql.NullInt32).Int32
 case "int64":
 value = dest[index].(*sql.NullInt64).Int64
 case "time.Time":
 value = dest[index].(*sql.NullTime).Time
 }
 build = builder.Set(build, col.Name(), value).(builder.Builder)
 default:
 build = builder.Set(build, col.Name(), val).(builder.Builder)
 }
 }
 return result{success: true, b: build}
}
func selectList(ctx context.Context, table string, where where, columns ...string) result {
 logger := ctx.Value("logger").(zerolog.Logger)
 tx := ctx.Value("tx").(*sqlx.Tx)
 var selectBuilder sqlex.SelectBuilder
 if len(columns) > 0 {
 selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is null")
 } else {
 selectBuilder = psql.Select("*").From(table).Where("deleted_at is null")
 }
 for _, arg := range where {
 selectBuilder = selectBuilder.Where(arg)
 }
 query, err := selectBuilder.RunWith(tx).Query()
 if err != nil {
 logger.Error().Caller(1).Err(err).Send()
 return result{success: false}
 }
 columnTypes, err := query.ColumnTypes()
 if err != nil {
 logger.Error().Caller(1).Err(err).Send()
 return result{success: false}
 }
 var resultSlice []interface{}
 for query.Next() {
 r := get(query, columnTypes, logger)
 if !r.success {
 return r
 }
 resultSlice = append(resultSlice, r.b)
 }
 return result{success: true, b: builder.Set(builder.EmptyBuilder, "list", resultSlice).(builder.Builder)}
}
func selectOne(ctx context.Context, table string, where where, columns ...string) result {
 logger := ctx.Value("logger").(zerolog.Logger)
 tx := ctx.Value("tx").(*sqlx.Tx)
 var selectBuilder sqlex.SelectBuilder
 if len(columns) > 0 {
 selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is null").Limit(1)
 } else {
 selectBuilder = psql.Select("*").From(table).Where("deleted_at is null").Limit(1)
 }
 for _, arg := range where {
 selectBuilder = selectBuilder.Where(arg)
 }
 query, err := selectBuilder.RunWith(tx).Query()
 if err != nil {
 logger.Error().Caller(1).Err(err).Send()
 return result{success: false}
 }
 columnTypes, err := query.ColumnTypes()
 if err != nil {
 logger.Error().Caller(1).Err(err).Send()
 return result{success: false}
 }
 if query.Next() {
 return get(query, columnTypes, logger)
 }
 return result{success: false}
}
func selectReal(ctx context.Context, table string, where where, columns ...string) result {
 logger := ctx.Value("logger").(zerolog.Logger)
 tx := ctx.Value("tx").(*sqlx.Tx)
 var selectBuilder sqlex.SelectBuilder
 if len(columns) > 0 {
 selectBuilder = psql.Select(columns...).From(table).Where("deleted_at is not null")
 } else {
 selectBuilder = psql.Select("*").From(table).Where("deleted_at is not null")
 }
 for _, arg := range where {
 selectBuilder = selectBuilder.Where(arg)
 }
 query, err := selectBuilder.RunWith(tx).Query()
 if err != nil {
 logger.Error().Caller(1).Err(err).Send()
 return result{success: false}
 }
 columnTypes, err := query.ColumnTypes()
 if err != nil {
 logger.Error().Caller(1).Err(err).Send()
 return result{success: false}
 }
 var resultSlice []interface{}
 for query.Next() {
 r := get(query, columnTypes, logger)
 if !r.success {
 return r
 }
 resultSlice = append(resultSlice, r.b)
 }
 return result{success: true, b: builder.Set(builder.EmptyBuilder, "list", resultSlice).(builder.Builder)}
}
func insertOne(ctx context.Context, table string, cv cv) result {
 logger := ctx.Value("logger").(zerolog.Logger)
 tx := ctx.Value("tx").(*sqlx.Tx)
 build := builder.EmptyBuilder
 cv["created_at"], cv["updated_at"] = time.Now(), time.Now()
 columns, values := make([]string, 0, len(cv)), make([]interface{}, 0, len(cv))
 for col, value := range cv {
 build = builder.Set(build, col, value).(builder.Builder)
 columns, values = append(columns, col), append(values, value)
 }
 r, err := psql.Insert(table).Columns(columns...).Values(values...).RunWith(tx).Exec()
 return assertSqlResult(r, err, logger)
}
func update(ctx context.Context, table string, cv cv, where where, directSet ...string) result {
 logger := ctx.Value("logger").(zerolog.Logger)
 tx := ctx.Value("tx").(*sqlx.Tx)
 cv["updated_at"] = time.Now()
 updateBuilder := psql.Update(table).SetMap(cv).Where("deleted_at is null")
 for _, set := range directSet {
 updateBuilder = updateBuilder.DirectSet(set)
 }
 for _, arg := range where {
 updateBuilder = updateBuilder.Where(arg)
 }
 r, err := updateBuilder.RunWith(tx).Exec()
 return assertSqlResult(r, err, logger)
}
// note: if where is null,then will delete the whole table
func remove(ctx context.Context, table string, where where) result {
 logger := ctx.Value("logger").(zerolog.Logger)
 tx := ctx.Value("tx").(*sqlx.Tx)
 updateBuilder := psql.Update(table).Set("deleted_at", time.Now()).Where("deleted_at is null")
 for _, arg := range where {
 updateBuilder = updateBuilder.Where(arg)
 }
 r, err := updateBuilder.RunWith(tx).Exec()
 return assertSqlResult(r, err, logger)
}
func assertSqlResult(r sql.Result, err error, logger zerolog.Logger, skip ...int) result {
 sk := defaultSkip
 if len(skip) > 0 {
 sk += skip[0]
 }
 if err != nil {
 logger.Error().Caller(sk).Err(err).Send()
 return result{success: false}
 }
 affected, err := r.RowsAffected()
 if err != nil {
 logger.Error().Caller(2).Err(err).Send()
 return result{success: false}
 }
 if affected == 0 {
 return result{success: false}
 }
 return result{success: true}
}

在这里我们只看查询,selectList和selectOne依托于get方法实现,而get的核心就是设值。因为在数据库中,数据存在NULL的情况,而Go中的基础类型如string,int64等并不支持,所以我们必须使用其对应的sql.NullString等类型去scan。作者这里为了保持model中定义的struct能够继续使用string等基础类型,在get中进行了类型的判断,不可空的基础类型通过两次switch转换,最终即便对于NULL值,也会得到基础类型的默认空值。

在get方法中,我们使用reflect.New(col.ScanType()).Interface()方法,获得字段对应的指针值,这里使用了反射,效果等同于new()。

在记录错误日志logger.Error().Caller(sk).Err(err).Send()时,我们先指定了日志的类别为Error,再调用了Caller(sk),获取运行时上下文。Caller的原理是调用runtime.Caller(skip)方法,以获取指定的代码段位置。最终效果就是通常我们程序报错时,在控制台能够看到的,各个文件的指定行。

在get方法的最后,我们通过builder.Set(build, col.Name(), value).(builder.Builder)这样的代码段,将数据对应的名字和值存入指定的builer中。builder的效果类似于map,只是使用builder库可以更方便直接将map转为指定的struct。

再把目光转到selectOne方法,可以看到我们从上下文context中获取了logger和事务tx,这里是方便后续的工作。我们需要注意的是,sqlex库进行sql构建时,严格按照了sql语法的规定,当然where和from之间的顺序在这里可以不用管。我们在初始化selectBuilder的时候,Where("1=1")给定了一个初始的where条件,这样做的用意是,由于sqlex库提供了IF操作,譬如:

 psql.Select("*").From("user").Where(sqlex.IF{Condition: "a" == "", Sq: sqlex.Eq{"a": "3"}})

这样的代码,由于"a"==""不满足,所以IF中的"a"=="3"并不会被纳入构建器中,可是也因为调用了Where,所以构建器中sql中必然会增加一个where,最终得到错误的sql:SELECT * FROM "user" WHERE

编写Model

model目录下新建user.go文件:

package model
import (
 "context"
 "errors"
 "github.com/unrotten/builder"
 "time"
)
type User struct {
 Id int64 `json:"id" db:"id"`
 Username string `json:"username" db:"username"`
 Email string `json:"email" db:"email"`
 Password string `json:"password" db:"password"`
 Avatar string `json:"avatar" db:"avatar"`
 Gender string `json:"gender" db:"gender"`
 Introduce string `json:"introduce" db:"introduce"`
 State string `json:"state" db:"state"`
 Root bool `json:"root" db:"root"`
 CreatedAt time.Time `json:"createdAt" db:"created_at"`
 UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
 DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}
func GetUsers(ctx context.Context, where where) ([]User, error) {
 result := selectList(ctx, `"user"`, where)
 if !result.success {
 return nil, errors.New("获取用户列表失败")
 }
 list, ok := builder.Get(result.b, "list")
 if !ok {
 return nil, errors.New("获取用户列表失败")
 }
 users := make([]User, 0, len(list.([]interface{})))
 for _, item := range list.([]interface{}) {
 users = append(users, builder.GetStructLikeByTag(item.(builder.Builder), User{}, "db").(User))
 }
 return users, nil
}
func GetUser(ctx context.Context, where where) (User, error) {
 result := selectOne(ctx, `"user"`, where)
 if !result.success {
 return User{}, errors.New("查询用户数据失败")
 }
 return builder.GetStructLikeByTag(result.b, User{}, "db").(User), nil
}
func InsertUser(ctx context.Context, cv map[string]interface{}) (User, error) {
 id, err := idfetcher.NextID()
 if err != nil {
 return User{}, err
 }
 cv["id"] = int64(id)
 result := insertOne(ctx, `"user"`, cv)
 if !result.success {
 return User{}, errors.New("插入用户数据失败")
 }
 return builder.GetStructLikeByTag(result.b, User{}, "db").(User), nil
}
func UpdateUser(ctx context.Context, cv cv, where where) error {
 result := update(ctx, `"user"`, cv, where)
 if !result.success {
 return errors.New("更新用户数据失败")
 }
 return nil
}

这里唯一需要注意的是,我们使用builder.GetStructLikeByTag(result.b, User{}, "db").(User)方法,将CURD中获得的Builder根据指定的tag内容,转化为对应结构体。

接下来,就是继续完善其他的model。

userCount.go:

package model
import (
 "context"
 "errors"
 "github.com/unrotten/builder"
 "github.com/unrotten/sqlex"
 "time"
)
type UserCount struct {
 Uid int64 `json:"uid" db:"uid"`
 FansNum int32 `json:"fansNum" db:"fans_num"`
 FollowNum int32 `json:"followNum" db:"follow_num"`
 ArticleNum int32 `json:"articleNum" db:"article_num"`
 Words int32 `json:"words" db:"words"`
 ZanNum int32 `json:"zanNum" db:"zan_num"`
 CreatedAt time.Time `json:"createdAt" db:"created_at"`
 UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
 DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}
func GetUserCount(ctx context.Context, uid int64, columns ...string) (UserCount, error) {
 result := selectOne(ctx, "user_count", append(where{}, sqlex.Eq{"uid": uid}), columns...)
 if !result.success {
 return UserCount{}, errors.New("查询用户计数失败")
 }
 return builder.GetStructLikeByTag(result.b, UserCount{}, "db").(UserCount), nil
}
func InsertUserCount(ctx context.Context, uid int64) error {
 result := insertOne(ctx, "user_count", cv{"uid": uid})
 if !result.success {
 return errors.New("保存用户计数表失败")
 }
 return nil
}
func UpdateUserCount(ctx context.Context, uid int64, add bool, columns ...string) error {
 directSets, directSet := make([]string, 0, len(columns)), " + 1"
 if !add {
 directSet = " - 1"
 }
 for _, col := range columns {
 directSets = append(directSets, col+directSet)
 }
 if !update(ctx, "user_count", cv{}, where{sqlex.Eq{"uid": uid}}, directSets...).success {
 return errors.New("增加用户计数失败")
 }
 return nil
}

我们为了改变userCount中的计数值,定义了方法UpdateUserCount。可以通过指定加减和相应字段来实现计数值的加减。我们可以注意到了,这里在调用update的时候,传入了directSets,最终将通过update中的:

for _, set := range directSet {
 updateBuilder = updateBuilder.DirectSet(set)
}

将设置好的值构建到SQL中。DirectSet目的是构建无参数的set语句,所以并不建议暴露给从接口传入的参数,否则会有SQL注入的风险。

userFollow.go:

package model
import (
 "context"
 "errors"
 "github.com/unrotten/builder"
 "github.com/unrotten/sqlex"
 "time"
)
type UserFollow struct {
 Id int64 `json:"id" db:"id"`
 Uid int64 `json:"uid" db:"uid"`
 Fuid int64 `json:"fuid" db:"fuid"`
 CreatedAt time.Time `json:"createdAt" db:"created_at"`
 UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
 DeletedAt time.Time `json:"deletedAt" db:"deleted_at"`
}
func InsertUserFollow(ctx context.Context, uid, fuid int64) error {
 id, err := idfetcher.NextID()
 if err != nil {
 return err
 }
 if result := insertOne(ctx, "user_follow", cv{"id": int64(id), "uid": uid, "fuid": fuid}); !result.success {
 return errors.New("插入用户关注表失败")
 }
 return nil
}
func RemoveUserFollow(ctx context.Context, uid, fuid int64) error {
 if !remove(ctx, "user_follow", where{sqlex.Eq{"uid": uid, "fuid": fuid}}).success {
 return errors.New("删除用户关注失败")
 }
 return nil
}
// 获取用户关注列表
func GetUserFollowList(ctx context.Context, fuid int64) ([]int64, error) {
 result := selectList(ctx, "user_follow", where{sqlex.Eq{"fuid": fuid}}, "uid")
 if !result.success {
 return nil, errors.New("获取用户关注列表失败")
 }
 b, _ := builder.Get(result.b, "list")
 list := b.([]interface{})
 userList := make([]int64, 0, len(list))
 for _, item := range list {
 uid, _ := builder.Get(item.(builder.Builder), "uid")
 userList = append(userList, uid.(int64))
 }
 return userList, nil
}
// 获取用户粉丝列表
func GetFollowUserList(ctx context.Context, uid int64) ([]int64, error) {
 result := selectList(ctx, "user_follow", where{sqlex.Eq{"uid": uid}}, "fuid")
 if !result.success {
 return nil, errors.New("获取用户关注列表失败")
 }
 b, _ := builder.Get(result.b, "list")
 list := b.([]interface{})
 userList := make([]int64, 0, len(list))
 for _, item := range list {
 uid, _ := builder.Get(item.(builder.Builder), "fuid")
 userList = append(userList, uid.(int64))
 }
 return userList, nil
}

在这里无论是粉丝列表还是关注列表,我们都指定了获取对应的userId列表,而非UserFollow数组。这是为了便于后续dataloader的使用,以后会提到。

到这里用户相关的model就编写完了,后面真正与前端一起联调时,定还有许多更改。而其他诸如文章,评论等的model,便不再赘述。用户相关的model,已经将基本的CURD涵盖。

看完这里,我们可以发现,对于user的扩展表user_count 和 user_follow, 我们并没有在model层面去设计他们的关系,在数据的获取,新增,修改上,也都是独立的。这是因为我们所有定义的数据之间的关系,都交由GraphQL去描述了,在数据层我们反而不用多在意这些关系的实现。


作者个人博客地址:https://unrotten.org
作者微信公众号地址:
WechatIMG2.jpeg


有疑问加站长微信联系(非本文作者)

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

关注微信
1103 次点击
暂无回复
添加一条新回复 (您需要 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传

用户登录

没有账号?注册
(追記) (追記ここまで)

今日阅读排行

    加载中
(追記) (追記ここまで)

一周阅读排行

    加载中

关注我

  • 扫码关注领全套学习资料 关注微信公众号
  • 加入 QQ 群:
    • 192706294(已满)
    • 731990104(已满)
    • 798786647(已满)
    • 729884609(已满)
    • 977810755(已满)
    • 815126783(已满)
    • 812540095(已满)
    • 1006366459(已满)
    • 692541889

  • 关注微信公众号
  • 加入微信群:liuxiaoyan-s,备注入群
  • 也欢迎加入知识星球 Go粉丝们(免费)

给该专栏投稿 写篇新文章

每篇文章有总共有 5 次投稿机会

收入到我管理的专栏 新建专栏