diff --git a/has_test.go b/has_test.go index 02a4583..611f999 100644 --- a/has_test.go +++ b/has_test.go @@ -1,6 +1,7 @@ package bob_test import ( + "reflect" "testing" "github.com/aldy505/bob" @@ -19,8 +20,8 @@ func TestHas(t *testing.T) { if sql != result { t.Fatal("sql is not equal with result:", sql) } - - if len(args) != 1 { + argsResult := []interface{}{"users"} + if !reflect.DeepEqual(args, argsResult) { t.Fatal("args is not equal with argsResult:", args) } }) @@ -36,7 +37,8 @@ func TestHas(t *testing.T) { t.Fatal("sql is not equal with result:", sql) } - if len(args) != 2 { + argsResult := []interface{}{"users", "name"} + if !reflect.DeepEqual(args, argsResult) { t.Fatal("args is not equal with argsResult:", args) } }) @@ -52,7 +54,8 @@ func TestHas(t *testing.T) { t.Fatal("sql is not equal with result:", sql) } - if len(args) != 2 { + argsResult := []interface{}{"users", "name"} + if !reflect.DeepEqual(args, argsResult) { t.Fatal("args is not equal with argsResult:", args) } }) @@ -68,12 +71,13 @@ func TestHas(t *testing.T) { t.Fatal("sql is not equal with result:", sql) } - if len(args) != 2 { + argsResult := []interface{}{"users", "private"} + if !reflect.DeepEqual(args, argsResult) { t.Fatal("args is not equal with argsResult:", args) } }) - t.Run("should be able to have a different placeholder", func(t *testing.T) { + t.Run("should be able to have a different placeholder - dollar", func(t *testing.T) { sql, args, err := bob.HasTable("users").HasColumn("name").PlaceholderFormat(bob.Dollar).ToSql() if err != nil { t.Fatal(err.Error()) @@ -84,7 +88,8 @@ func TestHas(t *testing.T) { t.Fatal("sql is not equal with result:", sql) } - if len(args) != 2 { + argsResult := []interface{}{"users", "name"} + if !reflect.DeepEqual(args, argsResult) { t.Fatal("args is not equal with argsResult:", args) } }) diff --git a/upsert.go b/upsert.go index c1cec6c..2baf6b0 100644 --- a/upsert.go +++ b/upsert.go @@ -24,42 +24,49 @@ func init() { builder.Register(UpsertBuilder{}, upsertData{}) } -func (b UpsertBuilder) dialect(db int) UpsertBuilder { - return builder.Set(b, "Dialect", db).(UpsertBuilder) +func (u UpsertBuilder) dialect(db int) UpsertBuilder { + return builder.Set(u, "Dialect", db).(UpsertBuilder) } // Table sets which table to be dropped -func (b UpsertBuilder) Into(name string) UpsertBuilder { - return builder.Set(b, "Into", name).(UpsertBuilder) +func (u UpsertBuilder) Into(name string) UpsertBuilder { + return builder.Set(u, "Into", name).(UpsertBuilder) } -func (b UpsertBuilder) Columns(columns ...string) UpsertBuilder { - return builder.Extend(b, "Columns", columns).(UpsertBuilder) +func (u UpsertBuilder) Columns(columns ...string) UpsertBuilder { + return builder.Extend(u, "Columns", columns).(UpsertBuilder) } // Values sets the values in relation with the columns. // Please not that only string, int, and bool type are supported. // Inputting other types other than those might result in your SQL not working properly. -func (b UpsertBuilder) Values(values ...interface{}) UpsertBuilder { - return builder.Append(b, "Values", values).(UpsertBuilder) +func (u UpsertBuilder) Values(values ...interface{}) UpsertBuilder { + return builder.Append(u, "Values", values).(UpsertBuilder) } -func (b UpsertBuilder) Key(key ...interface{}) UpsertBuilder { - return builder.Extend(b, "Key", []interface{}{key[0], key[1]}).(UpsertBuilder) +func (u UpsertBuilder) Key(key ...interface{}) UpsertBuilder { + var value interface{} + column := key[0] + if len(key) > 1 && key[0] != nil { + value = key[1] + } else { + value = "" + } + return builder.Extend(u, "Key", []interface{}{column, value}).(UpsertBuilder) } -func (b UpsertBuilder) Replace(column interface{}, value interface{}) UpsertBuilder { - return builder.Extend(b, "Replace", []interface{}{column, value}).(UpsertBuilder) +func (u UpsertBuilder) Replace(column interface{}, value interface{}) UpsertBuilder { + return builder.Append(u, "Replace", []interface{}{column, value}).(UpsertBuilder) } // PlaceholderFormat changes the default placeholder (?) to desired placeholder. -func (b UpsertBuilder) PlaceholderFormat(f string) UpsertBuilder { - return builder.Set(b, "Placeholder", f).(UpsertBuilder) +func (u UpsertBuilder) PlaceholderFormat(f string) UpsertBuilder { + return builder.Set(u, "Placeholder", f).(UpsertBuilder) } // ToSql returns 3 variables filled out with the correct values based on bindings, etc. -func (b UpsertBuilder) ToSql() (string, []interface{}, error) { - data := builder.GetStruct(b).(upsertData) +func (u UpsertBuilder) ToSql() (string, []interface{}, error) { + data := builder.GetStruct(u).(upsertData) return data.ToSql() } @@ -121,6 +128,7 @@ func (d *upsertData) ToSql() (sqlStr string, args []interface{}, err error) { } sql.WriteString(strings.Join(values, ", ")) + sql.WriteString(" ") var replaces []string for i := 0; i < len(d.Replace); i++ { @@ -137,6 +145,11 @@ func (d *upsertData) ToSql() (sqlStr string, args []interface{}, err error) { } else if d.Dialect == Postgresql || d.Dialect == Sqlite { // INSERT INTO players (user_name, age) VALUES('steven', 32) ON CONFLICT(user_name) DO UPDATE SET age=excluded.age; + if len(d.Key) == 0 { + err = errors.New("unique key must be provided for PostgreSQL and SQLite") + return + } + sql.WriteString("ON CONFLICT ") sql.WriteString("(\""+d.Key[0].(string)+"\") ") sql.WriteString("DO UPDATE SET ") diff --git a/upsert_test.go b/upsert_test.go index 8c36537..03f77f4 100644 --- a/upsert_test.go +++ b/upsert_test.go @@ -8,30 +8,6 @@ import ( ) func TestUpsert(t *testing.T) { - t.Run("should be able to generate upsert query for postgres", func(t *testing.T) { - sql, args, err := bob. - Upsert("users", bob.Postgresql). - Columns("name", "email"). - Values("John Doe", "john@doe.com"). - Key("email", nil). - Replace("name", "John Does"). - PlaceholderFormat(bob.Dollar). - ToSql() - if err != nil { - t.Error(err) - } - - desiredSql := "INSERT INTO \"users\" (\"name\", \"email\") VALUES ($1, $2) ON CONFLICT (\"email\") DO UPDATE SET \"name\" = $3;" - desiredArgs := []interface{}{"John Doe", "john@doe.com", "John Does"} - - if sql != desiredSql { - t.Error("sql is not the same as result: ", sql) - } - if reflect.DeepEqual(args, desiredArgs) { - t.Error("args is not the same as result: ", args) - } - }) - t.Run("should be able to generate upsert query for mysql", func(t *testing.T) { sql, args, err := bob. Upsert("users", bob.Mysql). @@ -49,19 +25,19 @@ func TestUpsert(t *testing.T) { if sql != desiredSql { t.Error("sql is not the same as result: ", sql) } - if reflect.DeepEqual(args, desiredArgs) { + if !reflect.DeepEqual(args, desiredArgs) { t.Error("args is not the same as result: ", args) } }) - t.Run("should be able to generate upsert query for sqlite", func(t *testing.T) { + t.Run("should be able to generate upsert query for postgres", func(t *testing.T) { sql, args, err := bob. - Upsert("users", bob.Sqlite). + Upsert("users", bob.Postgresql). Columns("name", "email"). Values("John Doe", "john@doe.com"). - Key("email", nil). + Key("email"). Replace("name", "John Does"). - PlaceholderFormat(bob.Question). + PlaceholderFormat(bob.Dollar). ToSql() if err != nil { t.Error(err) @@ -73,7 +49,31 @@ func TestUpsert(t *testing.T) { if sql != desiredSql { t.Error("sql is not the same as result: ", sql) } - if reflect.DeepEqual(args, desiredArgs) { + if !reflect.DeepEqual(args, desiredArgs) { + t.Error("args is not the same as result: ", args) + } + }) + + t.Run("should be able to generate upsert query for sqlite", func(t *testing.T) { + sql, args, err := bob. + Upsert("users", bob.Sqlite). + Columns("name", "email"). + Values("John Doe", "john@doe.com"). + Key("email"). + Replace("name", "John Does"). + PlaceholderFormat(bob.Question). + ToSql() + if err != nil { + t.Error(err) + } + + desiredSql := "INSERT INTO \"users\" (\"name\", \"email\") VALUES (?, ?) ON CONFLICT (\"email\") DO UPDATE SET \"name\" = ?;" + desiredArgs := []interface{}{"John Doe", "john@doe.com", "John Does"} + + if sql != desiredSql { + t.Error("sql is not the same as result: ", sql) + } + if !reflect.DeepEqual(args, desiredArgs) { t.Error("args is not the same as result: ", args) } }) @@ -91,13 +91,13 @@ func TestUpsert(t *testing.T) { t.Error(err) } - desiredSql := "IF NOT EXISTS (SELECT * FROM \"users\" WHERE \"email\" = @p1) INSERT INTO \"users\" (\"name\", \"email\") VALUES (@p2, @p3) ELSE UPDATE SET \"name\" = @p4 WHERE \"users\" = @p5;" - desiredArgs := []interface{}{"john@doe.com", "John Doe", "john@doe.com", "John Does"} + desiredSql := "IF NOT EXISTS (SELECT * FROM \"users\" WHERE \"email\" = @p1) INSERT INTO \"users\" (\"name\", \"email\") VALUES (@p2, @p3) ELSE UPDATE \"users\" SET \"name\" = @p4 WHERE \"email\" = @p5;" + desiredArgs := []interface{}{"john@doe.com", "John Doe", "john@doe.com", "John Does", "john@doe.com"} if sql != desiredSql { t.Error("sql is not the same as result: ", sql) } - if reflect.DeepEqual(args, desiredArgs) { + if !reflect.DeepEqual(args, desiredArgs) { t.Error("args is not the same as result: ", args) } }) diff --git a/util/arguments.go b/util/arguments.go index 96951db..30d7119 100644 --- a/util/arguments.go +++ b/util/arguments.go @@ -2,7 +2,7 @@ package util // createArgs should create an argument []interface{} for SQL query // I'm using the idiot approach for creating args -func CreateArgs(keys ...string) []interface{} { +func CreateArgs(keys ...interface{}) []interface{} { var args []interface{} for _, v := range keys { if v == "" {