diff --git a/db/alter.go b/db/alter.go index 0a4ffdd..6028700 100644 --- a/db/alter.go +++ b/db/alter.go @@ -1,52 +1,52 @@ package db import ( "fmt" "strings" ) type AlterTableSqlBuilder struct { Dialect DialectType Name string Changes []string } func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { - if colVal, err := col.String(); err == nil { + if colVal, err := col.CreateSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) } return b } func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { - if colVal, err := col.String(); err == nil { - b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + if colActions, err := col.AlterSQL(b.Dialect, name); err == nil { + b.Changes = append(b.Changes, colActions...) } return b } func (b *AlterTableSqlBuilder) AddUniqueConstraint(name string, columns ...string) *AlterTableSqlBuilder { b.Changes = append(b.Changes, fmt.Sprintf("ADD CONSTRAINT %s UNIQUE (%s)", name, strings.Join(columns, ", "))) return b } func (b *AlterTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("ALTER TABLE ") str.WriteString(b.Name) str.WriteString(" ") if len(b.Changes) == 0 { return "", fmt.Errorf("no changes provide for table: %s", b.Name) } changeCount := len(b.Changes) for i, thing := range b.Changes { str.WriteString(thing) if i < changeCount-1 { str.WriteString(", ") } } return str.String(), nil } diff --git a/db/alter_test.go b/db/alter_test.go index 4bd58ac..4d47821 100644 --- a/db/alter_test.go +++ b/db/alter_test.go @@ -1,56 +1,71 @@ package db import "testing" func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { type fields struct { Dialect DialectType Name string Changes []string } tests := []struct { name string builder *AlterTableSqlBuilder want string wantErr bool }{ { name: "MySQL add int", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), - want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", + AddColumn(NonNullableColumn("the_col", ColumnTypeInt{MaxBytes: 4})), + want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL", wantErr: false, }, { name: "MySQL add string", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), + AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})), want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", wantErr: false, }, - { name: "MySQL add int and string", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). - AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), + AddColumn(NonNullableColumn("first_col", ColumnTypeInt{MaxBytes: 4})). + AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})), want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", wantErr: false, }, + { + name: "MySQL change to string", + builder: DialectMySQL. + AlterTable("the_table"). + ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})), + want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, MODIFY COLUMN new_col TEXT NOT NULL", + wantErr: false, + }, + { + name: "PostgreSQL change to int", + builder: DialectMySQL. + AlterTable("the_table"). + ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})), + want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.builder.ToSQL() if (err != nil) != tt.wantErr { t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("ToSQL() got = %v, want %v", got, tt.want) } }) } } diff --git a/db/builder.go b/db/builder.go new file mode 100644 index 0000000..d0f4fe4 --- /dev/null +++ b/db/builder.go @@ -0,0 +1,15 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +type SQLBuilder interface { + ToSQL() (string, error) +} diff --git a/db/column.go b/db/column.go new file mode 100644 index 0000000..5bf64e4 --- /dev/null +++ b/db/column.go @@ -0,0 +1,328 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +import ( + "fmt" + "strings" +) + +type Column struct { + Name string + Type ColumnType + Nullable bool + PrimaryKey bool +} + +func NullableColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: true, + PrimaryKey: false, + } +} + +func NonNullableColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: false, + PrimaryKey: false, + } +} + +func PrimaryKeyColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: false, + PrimaryKey: true, + } +} + +type ColumnType interface { + Name(DialectType) (string, error) + Default(DialectType) (string, error) +} + +type ColumnTypeInt struct { + IsSigned bool + MaxBytes int + MaxDigits int + HasDefault bool + DefaultVal int +} + +type ColumnTypeString struct { + IsFixedLength bool + MaxChars int + HasDefault bool + DefaultVal string +} + +type ColumnDefault int + +type ColumnTypeBool struct { + DefaultVal ColumnDefault +} + +const ( + NoDefault ColumnDefault = iota + DefaultFalse ColumnDefault = iota + DefaultTrue ColumnDefault = iota + DefaultNow ColumnDefault = iota +) + +type ColumnTypeDateTime struct { + DefaultVal ColumnDefault +} + +func (intCol ColumnTypeInt) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "INTEGER", nil + + case DialectMySQL, DialectPostgreSQL: + var colName string + switch intCol.MaxBytes { + case 1: + if d == DialectMySQL { + colName = "TINYINT" + } else { + colName = "SMALLINT" + } + case 2: + colName = "SMALLINT" + case 3: + if d == DialectMySQL { + colName = "MEDIUMINT" + } else { + colName = "INTEGER" + } + case 4: + colName = "INTEGER" + default: + colName = "BIGINT" + } + if d == DialectMySQL { + if intCol.MaxDigits > 0 { + colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) + } + if !intCol.IsSigned { + colName += " UNSIGNED" + } + } + return colName, nil + + default: + return "", fmt.Errorf("dialect %d does not support integer columns", d) + } +} + +func (intCol ColumnTypeInt) Default(d DialectType) (string, error) { + if intCol.HasDefault { + return fmt.Sprintf("%d", intCol.DefaultVal), nil + } + return "", nil +} + +func (strCol ColumnTypeString) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "TEXT", nil + + case DialectMySQL, DialectPostgreSQL: + if strCol.IsFixedLength { + if strCol.MaxChars > 0 { + return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil + } + return "CHAR", nil + } + + if strCol.MaxChars <= 0 { + return "TEXT", nil + } + if strCol.MaxChars < (1 << 16) { + return fmt.Sprintf("VARCHAR(%d)", strCol.MaxChars), nil + } + return "TEXT", nil + + default: + return "", fmt.Errorf("dialect %d does not support string columns", d) + } +} + +func (strCol ColumnTypeString) Default(d DialectType) (string, error) { + if strCol.HasDefault { + return EscapeSimple.SQLEscape(d, strCol.DefaultVal) + } + return "", nil +} + +func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "INTEGER", nil + case DialectMySQL, DialectPostgreSQL: + return "BOOL", nil + default: + return "", fmt.Errorf("boolean column type not supported for dialect %d", d) + } +} + +func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) { + switch boolCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultFalse: + return "0", nil + case DefaultTrue: + return "1", nil + default: + return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) + } +} + +func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + return "DATETIME", nil + case DialectPostgreSQL: + return "TIMESTAMP", nil + default: + return "", fmt.Errorf("datetime column type not supported for dialect %d", d) + } +} + +func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + switch dateTimeCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultNow: + switch d { + case DialectSQLite, DialectPostgreSQL: + return "CURRENT_TIMESTAMP", nil + case DialectMySQL: + return "NOW()", nil + } + } + return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d) + default: + return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d) + } +} + +func (c *Column) SetName(name string) *Column { + c.Name = name + return c +} + +func (c *Column) SetNullable(nullable bool) *Column { + c.Nullable = nullable + return c +} + +func (c *Column) SetPrimaryKey(pk bool) *Column { + c.PrimaryKey = pk + return c +} + +func (c *Column) SetType(t ColumnType) *Column { + c.Type = t + return c +} + +func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) { + var actions []string = make([]string, 0) + + switch d { + // MySQL does all modifications at once + case DialectMySQL: + sql, err := c.CreateSQL(d) + if err != nil { + return make([]string, 0), err + } + actions = append(actions, fmt.Sprintf("CHANGE COLUMN %s %s", oldName, sql)) + + // PostgreSQL does modifications piece by piece + case DialectPostgreSQL: + if oldName != c.Name { + actions = append(actions, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name)) + } + + typeStr, err := c.Type.Name(d) + if err != nil { + return make([]string, 0), err + } + + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeStr)) + var nullAction string + if c.Nullable { + nullAction = "DROP" + } else { + nullAction = "SET" + } + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullAction)) + + defaultStr, err := c.Type.Default(d) + if err != nil { + return make([]string, 0), err + } + if len(defaultStr) > 0 { + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultStr)) + } + + if c.PrimaryKey { + actions = append(actions, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name)) + } + + default: + return make([]string, 0), fmt.Errorf("dialect %d doesn't support altering column data type", d) + } + + return actions, nil +} + +func (c *Column) CreateSQL(d DialectType) (string, error) { + var str strings.Builder + + str.WriteString(c.Name) + + str.WriteString(" ") + typeStr, err := c.Type.Name(d) + if err != nil { + return "", err + } + + str.WriteString(typeStr) + + if !c.Nullable { + str.WriteString(" NOT NULL") + } + + defaultStr, err := c.Type.Default(d) + if err != nil { + return "", err + } + if len(defaultStr) > 0 { + str.WriteString(" DEFAULT ") + str.WriteString(defaultStr) + } + + if c.PrimaryKey { + str.WriteString(" PRIMARY KEY") + } + + return str.String(), nil +} diff --git a/db/column_test.go b/db/column_test.go new file mode 100644 index 0000000..175ec3a --- /dev/null +++ b/db/column_test.go @@ -0,0 +1,151 @@ +package db + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestColumnType_Name(t *testing.T) { + tests := []struct { + name string + ty ColumnType + d DialectType + want string + wantErr bool + }{ + {"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false}, + {"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false}, + {"SQLite string ", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT DEFAULT 'that''s a default'", false}, + {"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false}, + + {"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false}, + {"MySQL tiny int", ColumnTypeInt{MaxBytes: 1}, DialectMySQL, "TINYINT", false}, + {"MySQL tiny int with digits", ColumnTypeInt{MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false}, + {"MySQL small int", ColumnTypeInt{MaxBytes: 2}, DialectMySQL, "SMALLINT", false}, + {"MySQL small int with digits", ColumnTypeInt{MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false}, + {"MySQL medium int", ColumnTypeInt{MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false}, + {"MySQL medium int with digits", ColumnTypeInt{MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false}, + {"MySQL int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER", false}, + {"MySQL int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false}, + {"MySQL bigint", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "BIGINT", false}, + {"MySQL bigint with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false}, + {"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false}, + {"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false}, + {"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false}, + {"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false}, + {"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ty.Name(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Name() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumnType_Default(t *testing.T) { + tests := []struct { + name string + ty ColumnType + d DialectType + want string + wantErr bool + }{ + {"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false}, + {"SQLite bool false", ColumnTypeBool{}, DialectSQLite, "0", false}, + {"SQLite bool true", ColumnTypeBool{}, DialectSQLite, "1", false}, + {"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false}, + {"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false}, + {"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false}, + {"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false}, + {"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false}, + {"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false}, + {"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false}, + + {"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false}, + {"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false}, + {"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false}, + {"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ty.Default(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("Default() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Default() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumn_CreateSQL(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Nullable bool + Type ColumnType + PrimaryKey bool + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + {"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false}, + {"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false}, + {"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false}, + {"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false}, + {"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false}, + {"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false}, + {"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false}, + + {"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo TINYINT(1) NOT NULL", false}, + {"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo TINYINT(1)", false}, + {"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false}, + {"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 1}, false}, "foo TINYINT", false}, + {"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, + {"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 2}, false}, "foo SMALLINT", false}, + {"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER", false}, + {"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false}, + {"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{}, false}, "foo BIGINT", false}, + {"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false}, + {"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false}, + {"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false}, + {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false}, + {"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false}, + {"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false}, + {"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false}, + {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Column{ + Name: tt.fields.Name, + Nullable: tt.fields.Nullable, + Type: tt.fields.Type, + PrimaryKey: tt.fields.PrimaryKey, + } + if got, err := c.CreateSQL(tt.fields.Dialect); got != tt.want { + if (err != nil) != tt.wantErr { + t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("String() got = %v, want %v", got, tt.want) + } + } + }) + } +} diff --git a/db/create.go b/db/create.go index 1e9e679..af0ea2e 100644 --- a/db/create.go +++ b/db/create.go @@ -1,263 +1,86 @@ /* * Copyright © 2019-2020 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package db import ( "fmt" "strings" ) -type ColumnType int - -type OptionalInt struct { - Set bool - Value int -} - -type OptionalString struct { - Set bool - Value string -} - -type SQLBuilder interface { - ToSQL() (string, error) -} - -type Column struct { - Dialect DialectType - Name string - Nullable bool - Default OptionalString - Type ColumnType - Size OptionalInt - PrimaryKey bool -} - type CreateTableSqlBuilder struct { Dialect DialectType Name string IfNotExists bool ColumnOrder []string Columns map[string]*Column Constraints []string } -const ( - ColumnTypeBool ColumnType = iota - ColumnTypeSmallInt ColumnType = iota - ColumnTypeInteger ColumnType = iota - ColumnTypeChar ColumnType = iota - ColumnTypeVarChar ColumnType = iota - ColumnTypeText ColumnType = iota - ColumnTypeDateTime ColumnType = iota -) - -var _ SQLBuilder = &CreateTableSqlBuilder{} - -var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} -var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} - -func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { - if dialect != DialectMySQL && dialect != DialectSQLite { - return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) - } - switch d { - case ColumnTypeSmallInt: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "SMALLINT" + mod, nil - } - case ColumnTypeInteger: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "INT" + mod, nil - } - case ColumnTypeChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "CHAR" + mod, nil - } - case ColumnTypeVarChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "VARCHAR" + mod, nil - } - case ColumnTypeBool: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - return "TINYINT(1)", nil - } - case ColumnTypeDateTime: - return "DATETIME", nil - case ColumnTypeText: - return "TEXT", nil - } - return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) -} - -func (c *Column) SetName(name string) *Column { - c.Name = name - return c -} - -func (c *Column) SetNullable(nullable bool) *Column { - c.Nullable = nullable - return c -} - -func (c *Column) SetPrimaryKey(pk bool) *Column { - c.PrimaryKey = pk - return c -} - -func (c *Column) SetDefault(value string) *Column { - c.Default = OptionalString{Set: true, Value: value} - return c -} - -func (c *Column) SetDefaultCurrentTimestamp() *Column { - def := "NOW()" - if c.Dialect == DialectSQLite { - def = "CURRENT_TIMESTAMP" - } - c.Default = OptionalString{Set: true, Value: def} - return c -} - -func (c *Column) SetType(t ColumnType) *Column { - c.Type = t - return c -} - -func (c *Column) SetSize(size int) *Column { - c.Size = OptionalInt{Set: true, Value: size} - return c -} - -func (c *Column) String() (string, error) { - var str strings.Builder - - str.WriteString(c.Name) - - str.WriteString(" ") - typeStr, err := c.Type.Format(c.Dialect, c.Size) - if err != nil { - return "", err - } - - str.WriteString(typeStr) - - if !c.Nullable { - str.WriteString(" NOT NULL") - } - - if c.Default.Set { - str.WriteString(" DEFAULT ") - val := c.Default.Value - if val == "" { - val = "''" - } - str.WriteString(val) - } - - if c.PrimaryKey { - str.WriteString(" PRIMARY KEY") - } - - return str.String(), nil -} - func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) } b.Columns[column.Name] = column b.ColumnOrder = append(b.ColumnOrder, column.Name) return b } func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder { for _, column := range columns { if _, ok := b.Columns[column]; !ok { // This fails silently. return b } } b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ","))) return b } func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder { b.IfNotExists = ine return b } func (b *CreateTableSqlBuilder) ToSQL() (string, error) { var str strings.Builder str.WriteString("CREATE TABLE ") if b.IfNotExists { str.WriteString("IF NOT EXISTS ") } str.WriteString(b.Name) var things []string for _, columnName := range b.ColumnOrder { column, ok := b.Columns[columnName] if !ok { return "", fmt.Errorf("column not found: %s", columnName) } - columnStr, err := column.String() + columnStr, err := column.CreateSQL(b.Dialect) if err != nil { return "", err } things = append(things, columnStr) } things = append(things, b.Constraints...) if thingLen := len(things); thingLen > 0 { str.WriteString(" ( ") for i, thing := range things { str.WriteString(thing) if i < thingLen-1 { str.WriteString(", ") } } str.WriteString(" )") } return str.String(), nil } diff --git a/db/create_test.go b/db/create_test.go index 369d5c1..09efd18 100644 --- a/db/create_test.go +++ b/db/create_test.go @@ -1,146 +1,20 @@ package db import ( "github.com/stretchr/testify/assert" "testing" ) -func TestDialect_Column(t *testing.T) { - c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize) - assert.Equal(t, DialectSQLite, c1.Dialect) - c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize) - assert.Equal(t, DialectMySQL, c2.Dialect) -} - -func TestColumnType_Format(t *testing.T) { - type args struct { - dialect DialectType - size OptionalInt - } - tests := []struct { - name string - d ColumnType - args args - want string - wantErr bool - }{ - {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false}, - - {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false}, - {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false}, - {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false}, - {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false}, - {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false}, - {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false}, - {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false}, - {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false}, - {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false}, - {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false}, - {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false}, - - {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true}, - {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.d.Format(tt.args.dialect, tt.args.size) - if (err != nil) != tt.wantErr { - t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Format() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestColumn_Build(t *testing.T) { - type fields struct { - Dialect DialectType - Name string - Nullable bool - Default OptionalString - Type ColumnType - Size OptionalInt - PrimaryKey bool - } - tests := []struct { - name string - fields fields - want string - wantErr bool - }{ - {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false}, - {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, - {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false}, - {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, - {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, - - {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false}, - {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false}, - {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, - {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false}, - {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false}, - {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false}, - {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false}, - {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false}, - {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false}, - {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false}, - {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, - {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, - {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &Column{ - Dialect: tt.fields.Dialect, - Name: tt.fields.Name, - Nullable: tt.fields.Nullable, - Default: tt.fields.Default, - Type: tt.fields.Type, - Size: tt.fields.Size, - PrimaryKey: tt.fields.PrimaryKey, - } - if got, err := c.String(); got != tt.want { - if (err != nil) != tt.wantErr { - t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("String() got = %v, want %v", got, tt.want) - } - } - }) - } -} - func TestCreateTableSqlBuilder_ToSQL(t *testing.T) { sql, err := DialectMySQL. Table("foo"). SetIfNotExists(true). - Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)). - Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)). - Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")). + Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})). + Column(NonNullableColumn("baz", ColumnTypeString{})). + Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})). UniqueConstraint("bar"). UniqueConstraint("bar", "baz"). ToSQL() assert.NoError(t, err) assert.Equal(t, "CREATE TABLE IF NOT EXISTS foo ( bar INT NOT NULL PRIMARY KEY, baz TEXT NOT NULL, qux DATETIME NOT NULL DEFAULT NOW(), UNIQUE(bar), UNIQUE(bar,baz) )", sql) } diff --git a/db/dialect.go b/db/dialect.go index 4251465..3e2b90b 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -1,76 +1,51 @@ package db import "fmt" type DialectType int const ( - DialectSQLite DialectType = iota - DialectMySQL DialectType = iota + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota + DialectPostgreSQL DialectType = iota ) -func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { +func (d DialectType) IsKnown() bool { switch d { - case DialectSQLite: - return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} - case DialectMySQL: - return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + case DialectSQLite, DialectMySQL, DialectPostgreSQL: + return true default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) + return false } } -func (d DialectType) Table(name string) *CreateTableSqlBuilder { - switch d { - case DialectSQLite: - return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: +func (d DialectType) AssertKnown() { + if !d.IsKnown() { panic(fmt.Sprintf("unexpected dialect: %d", d)) } } +func (d DialectType) Table(name string) *CreateTableSqlBuilder { + d.AssertKnown() + return &CreateTableSqlBuilder{Dialect: d, Name: name} +} + func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { - switch d { - case DialectSQLite: - return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &AlterTableSqlBuilder{Dialect: d, Name: name} } func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { - switch d { - case DialectSQLite: - return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} - case DialectMySQL: - return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns} } func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { - switch d { - case DialectSQLite: - return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} - case DialectMySQL: - return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns} } func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { - switch d { - case DialectSQLite: - return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} - case DialectMySQL: - return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table} } diff --git a/db/escape.go b/db/escape.go new file mode 100644 index 0000000..53b8ef3 --- /dev/null +++ b/db/escape.go @@ -0,0 +1,63 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +import ( + "strings" +) + +type EscapeContext int + +const ( + EscapeSimple EscapeContext = iota +) + +func (_ EscapeContext) SQLEscape(d DialectType, s string) (string, error) { + builder := strings.Builder{} + switch d { + case DialectSQLite: + builder.WriteRune('\'') + for _, c := range s { + if c == '\'' { + builder.WriteString("''") + } else { + builder.WriteRune(c) + } + } + builder.WriteRune('\'') + case DialectMySQL: + builder.WriteRune('\'') + for _, c := range s { + switch c { + case 0: + builder.WriteString("\\0") + case '\'': + builder.WriteString("\\'") + case '"': + builder.WriteString("\\\"") + case '\b': + builder.WriteString("\\b") + case '\n': + builder.WriteString("\\n") + case '\r': + builder.WriteString("\\r") + case '\t': + builder.WriteString("\\t") + case '\\': + builder.WriteString("\\\\") + default: + builder.WriteRune(c) + } + } + builder.WriteRune('\'') + } + return builder.String(), nil +} diff --git a/migrations/v4.go b/migrations/v4.go index 4ae267d..6d0c9f2 100644 --- a/migrations/v4.go +++ b/migrations/v4.go @@ -1,54 +1,54 @@ /* * Copyright © 2019-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauth(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { createTableUsersOauth, err := dialect. Table("oauth_users"). SetIfNotExists(false). - Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). - Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})). + Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})). ToSQL() if err != nil { return err } createTableOauthClientState, err := dialect. Table("oauth_client_states"). SetIfNotExists(false). - Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). - Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). - Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()). + Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})). + Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})). + Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})). UniqueConstraint("state"). ToSQL() if err != nil { return err } for _, table := range []string{createTableUsersOauth, createTableOauthClientState} { if _, err := tx.ExecContext(ctx, table); err != nil { return err } } return nil }) } diff --git a/migrations/v5.go b/migrations/v5.go index db18fa1..4508b18 100644 --- a/migrations/v5.go +++ b/migrations/v5.go @@ -1,88 +1,105 @@ /* * Copyright © 2019-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauthSlack(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "provider", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 24, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "client_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 128, + HasDefault: true, + DefaultVal: "", + }, + )), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "provider", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 24, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "client_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 128, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "access_token", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 512, + HasDefault: true, + DefaultVal: "", + })), dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"), } if dialect != wf_db.DialectSQLite { // This updates the length of the `remote_user_id` column. It isn't needed for SQLite databases. builders = append(builders, dialect. AlterTable("oauth_users"). ChangeColumn("remote_user_id", - dialect. - Column( + wf_db. + NonNullableColumn( "remote_user_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}))) + wf_db.ColumnTypeString{ + MaxChars: 128, + }))) } for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) } diff --git a/migrations/v7.go b/migrations/v7.go index 2056aa0..a9af405 100644 --- a/migrations/v7.go +++ b/migrations/v7.go @@ -1,46 +1,48 @@ /* * Copyright © 2020-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauthAttach(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NullableColumn( "attach_user_id", - wf_db.ColumnTypeInteger, - wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)), + wf_db.ColumnTypeInt{ + MaxBytes: 4, + MaxDigits: 24, + })), } for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) } diff --git a/migrations/v8.go b/migrations/v8.go index 36001af..ded61c9 100644 --- a/migrations/v8.go +++ b/migrations/v8.go @@ -1,45 +1,45 @@ /* * Copyright © 2020-2021 Musing Studio LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package migrations import ( "context" "database/sql" wf_db "github.com/writefreely/writefreely/db" ) func oauthInvites(db *datastore) error { dialect := wf_db.DialectMySQL if db.driverName == driverSQLite { dialect = wf_db.DialectSQLite } return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{ - Set: true, - Value: 6, - }).SetNullable(true)), + AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{ + IsFixedLength: true, + MaxChars: 6, + })), } for _, builder := range builders { query, err := builder.ToSQL() if err != nil { return err } if _, err := tx.ExecContext(ctx, query); err != nil { return err } } return nil }) }