From b0057da7bff55caf0d2c9c2155913954350773b4 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 3 Mar 2024 17:17:12 -0800 Subject: [PATCH 1/3] fix(sqlite): Correctly skip unknown statements Without updating loc, the unknown statement would be included in the text of the next query. --- .../testdata/sqlite_skip_todo/db/db.go | 31 +++++++++ .../testdata/sqlite_skip_todo/db/models.go | 13 ++++ .../testdata/sqlite_skip_todo/db/query.sql.go | 66 +++++++++++++++++++ .../testdata/sqlite_skip_todo/query.sql | 16 +++++ .../testdata/sqlite_skip_todo/schema.sql | 3 + .../testdata/sqlite_skip_todo/sqlc.json | 16 +++++ internal/engine/sqlite/parse.go | 1 + 7 files changed, 146 insertions(+) create mode 100644 internal/endtoend/testdata/sqlite_skip_todo/db/db.go create mode 100644 internal/endtoend/testdata/sqlite_skip_todo/db/models.go create mode 100644 internal/endtoend/testdata/sqlite_skip_todo/db/query.sql.go create mode 100644 internal/endtoend/testdata/sqlite_skip_todo/query.sql create mode 100644 internal/endtoend/testdata/sqlite_skip_todo/schema.sql create mode 100644 internal/endtoend/testdata/sqlite_skip_todo/sqlc.json diff --git a/internal/endtoend/testdata/sqlite_skip_todo/db/db.go b/internal/endtoend/testdata/sqlite_skip_todo/db/db.go new file mode 100644 index 0000000000..bdb151c184 --- /dev/null +++ b/internal/endtoend/testdata/sqlite_skip_todo/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/sqlite_skip_todo/db/models.go b/internal/endtoend/testdata/sqlite_skip_todo/db/models.go new file mode 100644 index 0000000000..a1065e0b7e --- /dev/null +++ b/internal/endtoend/testdata/sqlite_skip_todo/db/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package db + +import ( + "database/sql" +) + +type Foo struct { + Bar sql.NullString +} diff --git a/internal/endtoend/testdata/sqlite_skip_todo/db/query.sql.go b/internal/endtoend/testdata/sqlite_skip_todo/db/query.sql.go new file mode 100644 index 0000000000..de32605668 --- /dev/null +++ b/internal/endtoend/testdata/sqlite_skip_todo/db/query.sql.go @@ -0,0 +1,66 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: query.sql + +package db + +import ( + "context" + "database/sql" +) + +const getFoo = `-- name: GetFoo :many +SELECT bar FROM foo +WHERE bar = ? +` + +func (q *Queries) GetFoo(ctx context.Context, bar sql.NullString) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, getFoo, bar) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var bar sql.NullString + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listFoo = `-- name: ListFoo :many +SELECT bar FROM foo +` + +func (q *Queries) ListFoo(ctx context.Context) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, listFoo) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var bar sql.NullString + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/sqlite_skip_todo/query.sql b/internal/endtoend/testdata/sqlite_skip_todo/query.sql new file mode 100644 index 0000000000..e51c45c8c1 --- /dev/null +++ b/internal/endtoend/testdata/sqlite_skip_todo/query.sql @@ -0,0 +1,16 @@ +-- name: PragmaForeignKeysEnable :exec +PRAGMA foreign_keys = 1; + +-- name: ListFoo :many +SELECT * FROM foo; + +-- name: PragmaForeignKeysDisable :exec +PRAGMA foreign_keys = 0; + +-- name: PragmaForeignKeysGet :one +PRAGMA foreign_keys; + +-- name: GetFoo :many +SELECT * FROM foo +WHERE bar = ?; + diff --git a/internal/endtoend/testdata/sqlite_skip_todo/schema.sql b/internal/endtoend/testdata/sqlite_skip_todo/schema.sql new file mode 100644 index 0000000000..010d09e16f --- /dev/null +++ b/internal/endtoend/testdata/sqlite_skip_todo/schema.sql @@ -0,0 +1,3 @@ +CREATE TABLE foo ( + bar text +); diff --git a/internal/endtoend/testdata/sqlite_skip_todo/sqlc.json b/internal/endtoend/testdata/sqlite_skip_todo/sqlc.json new file mode 100644 index 0000000000..cbd787d930 --- /dev/null +++ b/internal/endtoend/testdata/sqlite_skip_todo/sqlc.json @@ -0,0 +1,16 @@ +{ + "version": "2", + "sql": [ + { + "engine": "sqlite", + "queries": "query.sql", + "schema": "schema.sql", + "gen": { + "go": { + "package": "db", + "out": "db" + } + } + } + ] +} diff --git a/internal/engine/sqlite/parse.go b/internal/engine/sqlite/parse.go index 6da7b87112..13425b156e 100644 --- a/internal/engine/sqlite/parse.go +++ b/internal/engine/sqlite/parse.go @@ -69,6 +69,7 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) { converter := &cc{} out := converter.convert(stmt) if _, ok := out.(*ast.TODO); ok { + loc = stmt.GetStop().GetStop() + 2 continue } len := (stmt.GetStop().GetStop() + 1) - loc From 1849638ed173528e58deb90504229f48b71e7d10 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 3 Mar 2024 17:23:24 -0800 Subject: [PATCH 2/3] Pick sqlc-gen-typescript to a release --- .github/workflows/ci-typescript.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci-typescript.yml b/.github/workflows/ci-typescript.yml index 04a1b7b277..f4b94e6c10 100644 --- a/.github/workflows/ci-typescript.yml +++ b/.github/workflows/ci-typescript.yml @@ -19,5 +19,7 @@ jobs: with: repository: sqlc-dev/sqlc-gen-typescript path: typescript + # v0.1.3 + ref: daaf539092421adc15f6c3164279a3470716b560 - run: sqlc diff working-directory: typescript/examples From 8ff9cd7c9eb37a68331850fc3358c4149cd6ab81 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 3 Mar 2024 20:18:42 -0800 Subject: [PATCH 3/3] fix(opts): Validate SQL package and driver options --- internal/codegen/golang/driver.go | 46 ++----- internal/codegen/golang/gen.go | 8 +- internal/codegen/golang/imports.go | 16 +-- internal/codegen/golang/opts/enum.go | 64 ++++++++++ internal/codegen/golang/opts/options.go | 12 ++ internal/codegen/golang/postgresql_type.go | 118 +++++++++--------- internal/codegen/golang/query.go | 3 +- .../golang_invalid_sql_driver/db/db.go | 31 +++++ .../golang_invalid_sql_driver/db/models.go | 13 ++ .../golang_invalid_sql_driver/db/query.sql.go | 38 ++++++ .../golang_invalid_sql_driver/query.sql | 2 + .../golang_invalid_sql_driver/schema.sql | 3 + .../golang_invalid_sql_driver/sqlc.json | 16 +++ .../golang_invalid_sql_driver/stderr.txt | 2 + .../golang_invalid_sql_package/db/db.go | 31 +++++ .../golang_invalid_sql_package/db/models.go | 13 ++ .../db/query.sql.go | 38 ++++++ .../golang_invalid_sql_package/query.sql | 2 + .../golang_invalid_sql_package/schema.sql | 3 + .../golang_invalid_sql_package/sqlc.json | 16 +++ .../golang_invalid_sql_package/stderr.txt | 2 + 21 files changed, 366 insertions(+), 111 deletions(-) create mode 100644 internal/codegen/golang/opts/enum.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/db/db.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/db/models.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/db/query.sql.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/query.sql create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/schema.sql create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/sqlc.json create mode 100644 internal/endtoend/testdata/golang_invalid_sql_driver/stderr.txt create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/db/db.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/db/models.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/db/query.sql.go create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/query.sql create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/schema.sql create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/sqlc.json create mode 100644 internal/endtoend/testdata/golang_invalid_sql_package/stderr.txt diff --git a/internal/codegen/golang/driver.go b/internal/codegen/golang/driver.go index 7ef723b55e..5e3a533dcc 100644 --- a/internal/codegen/golang/driver.go +++ b/internal/codegen/golang/driver.go @@ -1,46 +1,14 @@ package golang -type SQLDriver string +import "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts" -const ( - SQLPackagePGXV4 string = "pgx/v4" - SQLPackagePGXV5 string = "pgx/v5" - SQLPackageStandard string = "database/sql" -) - -const ( - SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4" - SQLDriverPGXV5 = "github.com/jackc/pgx/v5" - SQLDriverLibPQ = "github.com/lib/pq" - SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql" -) - -func parseDriver(sqlPackage string) SQLDriver { +func parseDriver(sqlPackage string) opts.SQLDriver { switch sqlPackage { - case SQLPackagePGXV4: - return SQLDriverPGXV4 - case SQLPackagePGXV5: - return SQLDriverPGXV5 - default: - return SQLDriverLibPQ - } -} - -func (d SQLDriver) IsPGX() bool { - return d == SQLDriverPGXV4 || d == SQLDriverPGXV5 -} - -func (d SQLDriver) IsGoSQLDriverMySQL() bool { - return d == SQLDriverGoSQLDriverMySQL -} - -func (d SQLDriver) Package() string { - switch d { - case SQLDriverPGXV4: - return SQLPackagePGXV4 - case SQLDriverPGXV5: - return SQLPackagePGXV5 + case opts.SQLPackagePGXV4: + return opts.SQLDriverPGXV4 + case opts.SQLPackagePGXV5: + return opts.SQLDriverPGXV5 default: - return SQLPackageStandard + return opts.SQLDriverLibPQ } } diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index ebd3cf2efc..5b7977f500 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -19,7 +19,7 @@ import ( type tmplCtx struct { Q string Package string - SQLDriver SQLDriver + SQLDriver opts.SQLDriver Enums []Enum Structs []Struct GoQueries []Query @@ -189,15 +189,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, OmitSqlcVersion: options.OmitSqlcVersion, } - if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL { + if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL { return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql") } - if tctx.UsesCopyFrom && options.SqlDriver == SQLDriverGoSQLDriverMySQL { + if tctx.UsesCopyFrom && options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL { if err := checkNoTimesForMySQLCopyFrom(queries); err != nil { return nil, err } - tctx.SQLDriver = SQLDriverGoSQLDriverMySQL + tctx.SQLDriver = opts.SQLDriverGoSQLDriverMySQL } if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() { diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 613d597776..9e7819e4b1 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -126,10 +126,10 @@ func (i *importer) dbImports() fileImports { sqlpkg := parseDriver(i.Options.SqlPackage) switch sqlpkg { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"}) pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v4"}) - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}) pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"}) default: @@ -172,9 +172,9 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool for _, q := range queries { if q.Cmd == metadata.CmdExecResult { switch sqlpkg { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: pkg[ImportSpec{Path: "github.com/jackc/pgconn"}] = struct{}{} - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{} default: std["database/sql"] = struct{}{} @@ -189,7 +189,7 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool } if uses("pgtype.") { - if sqlpkg == SQLDriverPGXV5 { + if sqlpkg == opts.SQLDriverPGXV5 { pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgtype"}] = struct{}{} } else { pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{} @@ -429,7 +429,7 @@ func (i *importer) copyfromImports() fileImports { }) std["context"] = struct{}{} - if i.Options.SqlDriver == SQLDriverGoSQLDriverMySQL { + if i.Options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL { std["io"] = struct{}{} std["fmt"] = struct{}{} std["sync/atomic"] = struct{}{} @@ -481,9 +481,9 @@ func (i *importer) batchImports() fileImports { std["errors"] = struct{}{} sqlpkg := parseDriver(i.Options.SqlPackage) switch sqlpkg { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{} - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: pkg[ImportSpec{Path: "github.com/jackc/pgx/v5"}] = struct{}{} } diff --git a/internal/codegen/golang/opts/enum.go b/internal/codegen/golang/opts/enum.go new file mode 100644 index 0000000000..40457d040a --- /dev/null +++ b/internal/codegen/golang/opts/enum.go @@ -0,0 +1,64 @@ +package opts + +import "fmt" + +type SQLDriver string + +const ( + SQLPackagePGXV4 string = "pgx/v4" + SQLPackagePGXV5 string = "pgx/v5" + SQLPackageStandard string = "database/sql" +) + +var validPackages = map[string]struct{}{ + string(SQLPackagePGXV4): {}, + string(SQLPackagePGXV5): {}, + string(SQLPackageStandard): {}, +} + +func validatePackage(sqlPackage string) error { + if _, found := validPackages[sqlPackage]; !found { + return fmt.Errorf("unknown SQL package: %s", sqlPackage) + } + return nil +} + +const ( + SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4" + SQLDriverPGXV5 = "github.com/jackc/pgx/v5" + SQLDriverLibPQ = "github.com/lib/pq" + SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql" +) + +var validDrivers = map[string]struct{}{ + string(SQLDriverPGXV4): {}, + string(SQLDriverPGXV5): {}, + string(SQLDriverLibPQ): {}, + string(SQLDriverGoSQLDriverMySQL): {}, +} + +func validateDriver(sqlDriver string) error { + if _, found := validDrivers[sqlDriver]; !found { + return fmt.Errorf("unknown SQL driver: %s", sqlDriver) + } + return nil +} + +func (d SQLDriver) IsPGX() bool { + return d == SQLDriverPGXV4 || d == SQLDriverPGXV5 +} + +func (d SQLDriver) IsGoSQLDriverMySQL() bool { + return d == SQLDriverGoSQLDriverMySQL +} + +func (d SQLDriver) Package() string { + switch d { + case SQLDriverPGXV4: + return SQLPackagePGXV4 + case SQLDriverPGXV5: + return SQLPackagePGXV5 + default: + return SQLPackageStandard + } +} diff --git a/internal/codegen/golang/opts/options.go b/internal/codegen/golang/opts/options.go index 0b66975506..0e2a8562e5 100644 --- a/internal/codegen/golang/opts/options.go +++ b/internal/codegen/golang/opts/options.go @@ -94,6 +94,18 @@ func parseOpts(req *plugin.GenerateRequest) (*Options, error) { } } + if options.SqlPackage != "" { + if err := validatePackage(options.SqlPackage); err != nil { + return nil, fmt.Errorf("invalid options: %s", err) + } + } + + if options.SqlDriver != "" { + if err := validateDriver(options.SqlDriver); err != nil { + return nil, fmt.Errorf("invalid options: %s", err) + } + } + if options.QueryParameterLimit == nil { options.QueryParameterLimit = new(int32) *options.QueryParameterLimit = 1 diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index 855d5425c0..563cc09ab9 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -48,7 +48,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*int32" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Int4" } return "sql.NullInt32" @@ -60,7 +60,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*int64" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Int8" } return "sql.NullInt64" @@ -72,7 +72,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*int16" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Int2" } return "sql.NullInt16" @@ -84,7 +84,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*int32" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Int4" } return "sql.NullInt32" @@ -96,7 +96,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*int64" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Int8" } return "sql.NullInt64" @@ -108,7 +108,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*int16" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Int2" } return "sql.NullInt16" @@ -120,7 +120,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*float64" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Float8" } return "sql.NullFloat64" @@ -132,7 +132,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*float32" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Float4" } return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file @@ -160,18 +160,18 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*bool" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Bool" } return "sql.NullBool" case "json": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "[]byte" - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.JSON" - case SQLDriverLibPQ: + case opts.SQLDriverLibPQ: if notNull { return "json.RawMessage" } else { @@ -183,11 +183,11 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "jsonb": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "[]byte" - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.JSONB" - case SQLDriverLibPQ: + case opts.SQLDriverLibPQ: if notNull { return "json.RawMessage" } else { @@ -201,7 +201,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi return "[]byte" case "date": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Date" } if notNull { @@ -213,7 +213,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi return "sql.NullTime" case "pg_catalog.time": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Time" } if notNull { @@ -234,7 +234,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi return "sql.NullTime" case "pg_catalog.timestamp": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Timestamp" } if notNull { @@ -246,7 +246,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi return "sql.NullTime" case "pg_catalog.timestamptz", "timestamptz": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Timestamptz" } if notNull { @@ -264,13 +264,13 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*string" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Text" } return "sql.NullString" case "uuid": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.UUID" } if notNull { @@ -283,14 +283,14 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "inet": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: if notNull { return "netip.Addr" } return "*netip.Addr" - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Inet" - case SQLDriverLibPQ: + case opts.SQLDriverLibPQ: return "pqtype.Inet" default: return "interface{}" @@ -298,14 +298,14 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "cidr": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: if notNull { return "netip.Prefix" } return "*netip.Prefix" - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.CIDR" - case SQLDriverLibPQ: + case opts.SQLDriverLibPQ: return "pqtype.CIDR" default: return "interface{}" @@ -313,11 +313,11 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "macaddr", "macaddr8": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "net.HardwareAddr" - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Macaddr" - case SQLDriverLibPQ: + case opts.SQLDriverLibPQ: return "pqtype.Macaddr" default: return "interface{}" @@ -335,13 +335,13 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi if emitPointersForNull { return "*string" } - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Text" } return "sql.NullString" case "interval", "pg_catalog.interval": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Interval" } if notNull { @@ -354,9 +354,9 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "daterange": switch driver { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Daterange" - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Range[pgtype.Date]" default: return "interface{}" @@ -364,7 +364,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "datemultirange": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Multirange[pgtype.Range[pgtype.Date]]" default: return "interface{}" @@ -372,9 +372,9 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "tsrange": switch driver { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Tsrange" - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Range[pgtype.Timestamp]" default: return "interface{}" @@ -382,7 +382,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "tsmultirange": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Multirange[pgtype.Range[pgtype.Timestamp]]" default: return "interface{}" @@ -390,9 +390,9 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "tstzrange": switch driver { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Tstzrange" - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Range[pgtype.Timestamptz]" default: return "interface{}" @@ -400,7 +400,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "tstzmultirange": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Multirange[pgtype.Range[pgtype.Timestamptz]]" default: return "interface{}" @@ -408,9 +408,9 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "numrange": switch driver { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Numrange" - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Range[pgtype.Numeric]" default: return "interface{}" @@ -418,7 +418,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "nummultirange": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Multirange[pgtype.Range[pgtype.Numeric]]" default: return "interface{}" @@ -426,9 +426,9 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "int4range": switch driver { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Int4range" - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Range[pgtype.Int4]" default: return "interface{}" @@ -436,7 +436,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "int4multirange": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Multirange[pgtype.Range[pgtype.Int4]]" default: return "interface{}" @@ -444,9 +444,9 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "int8range": switch driver { - case SQLDriverPGXV4: + case opts.SQLDriverPGXV4: return "pgtype.Int8range" - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Range[pgtype.Int8]" default: return "interface{}" @@ -454,7 +454,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi case "int8multirange": switch driver { - case SQLDriverPGXV5: + case opts.SQLDriverPGXV5: return "pgtype.Multirange[pgtype.Range[pgtype.Int8]]" default: return "interface{}" @@ -467,26 +467,26 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi return "interface{}" case "bit", "varbit", "pg_catalog.bit", "pg_catalog.varbit": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Bits" } - if driver == SQLDriverPGXV4 { + if driver == opts.SQLDriverPGXV4 { return "pgtype.Varbit" } case "cid": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Uint32" } - if driver == SQLDriverPGXV4 { + if driver == opts.SQLDriverPGXV4 { return "pgtype.CID" } case "oid": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Uint32" } - if driver == SQLDriverPGXV4 { + if driver == opts.SQLDriverPGXV4 { return "pgtype.OID" } @@ -496,10 +496,10 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi } case "xid": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { return "pgtype.Uint32" } - if driver == SQLDriverPGXV4 { + if driver == opts.SQLDriverPGXV4 { return "pgtype.XID" } @@ -539,7 +539,7 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi } case "vector": - if driver == SQLDriverPGXV5 { + if driver == opts.SQLDriverPGXV5 { if emitPointersForNull { return "*pgvector.Vector" } else { diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index b82178686c..3b4fb2fa1a 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts" "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -15,7 +16,7 @@ type QueryValue struct { DBName string // The name of the field in the database. Only set if Struct==nil. Struct *Struct Typ string - SQLDriver SQLDriver + SQLDriver opts.SQLDriver // Column is kept so late in the generation process around to differentiate // between mysql slices and pg arrays diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/db/db.go b/internal/endtoend/testdata/golang_invalid_sql_driver/db/db.go new file mode 100644 index 0000000000..bdb151c184 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/db/models.go b/internal/endtoend/testdata/golang_invalid_sql_driver/db/models.go new file mode 100644 index 0000000000..a1065e0b7e --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/db/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package db + +import ( + "database/sql" +) + +type Foo struct { + Bar sql.NullString +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/db/query.sql.go b/internal/endtoend/testdata/golang_invalid_sql_driver/db/query.sql.go new file mode 100644 index 0000000000..242b7393be --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/db/query.sql.go @@ -0,0 +1,38 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: query.sql + +package db + +import ( + "context" + "database/sql" +) + +const selectFoo = `-- name: SelectFoo :many +SELECT bar FROM foo +` + +func (q *Queries) SelectFoo(ctx context.Context) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, selectFoo) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var bar sql.NullString + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/query.sql b/internal/endtoend/testdata/golang_invalid_sql_driver/query.sql new file mode 100644 index 0000000000..e32e926b32 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/query.sql @@ -0,0 +1,2 @@ +-- name: SelectFoo :many +SELECT * FROM foo; diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/schema.sql b/internal/endtoend/testdata/golang_invalid_sql_driver/schema.sql new file mode 100644 index 0000000000..1bd72529f8 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/schema.sql @@ -0,0 +1,3 @@ +CREATE TABLE foo( + bar text +); diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/sqlc.json b/internal/endtoend/testdata/golang_invalid_sql_driver/sqlc.json new file mode 100644 index 0000000000..6124f178d1 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/sqlc.json @@ -0,0 +1,16 @@ +{ + "version": "2", + "sql": [ + { + "schema": "schema.sql", + "queries": "query.sql", + "engine": "postgresql", + "gen": { + "go": { + "out": "db", + "sql_driver": "github.com/unknown/driver" + } + } + } + ] +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_driver/stderr.txt b/internal/endtoend/testdata/golang_invalid_sql_driver/stderr.txt new file mode 100644 index 0000000000..b71f130a2f --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_driver/stderr.txt @@ -0,0 +1,2 @@ +# package +error generating code: invalid options: unknown SQL driver: github.com/unknown/driver diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/db/db.go b/internal/endtoend/testdata/golang_invalid_sql_package/db/db.go new file mode 100644 index 0000000000..bdb151c184 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/db/models.go b/internal/endtoend/testdata/golang_invalid_sql_package/db/models.go new file mode 100644 index 0000000000..a1065e0b7e --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/db/models.go @@ -0,0 +1,13 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package db + +import ( + "database/sql" +) + +type Foo struct { + Bar sql.NullString +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/db/query.sql.go b/internal/endtoend/testdata/golang_invalid_sql_package/db/query.sql.go new file mode 100644 index 0000000000..242b7393be --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/db/query.sql.go @@ -0,0 +1,38 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: query.sql + +package db + +import ( + "context" + "database/sql" +) + +const selectFoo = `-- name: SelectFoo :many +SELECT bar FROM foo +` + +func (q *Queries) SelectFoo(ctx context.Context) ([]sql.NullString, error) { + rows, err := q.db.QueryContext(ctx, selectFoo) + if err != nil { + return nil, err + } + defer rows.Close() + var items []sql.NullString + for rows.Next() { + var bar sql.NullString + if err := rows.Scan(&bar); err != nil { + return nil, err + } + items = append(items, bar) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/query.sql b/internal/endtoend/testdata/golang_invalid_sql_package/query.sql new file mode 100644 index 0000000000..e32e926b32 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/query.sql @@ -0,0 +1,2 @@ +-- name: SelectFoo :many +SELECT * FROM foo; diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/schema.sql b/internal/endtoend/testdata/golang_invalid_sql_package/schema.sql new file mode 100644 index 0000000000..1bd72529f8 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/schema.sql @@ -0,0 +1,3 @@ +CREATE TABLE foo( + bar text +); diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/sqlc.json b/internal/endtoend/testdata/golang_invalid_sql_package/sqlc.json new file mode 100644 index 0000000000..a6c726061c --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/sqlc.json @@ -0,0 +1,16 @@ +{ + "version": "2", + "sql": [ + { + "schema": "schema.sql", + "queries": "query.sql", + "engine": "postgresql", + "gen": { + "go": { + "out": "db", + "sql_package": "pgx/5" + } + } + } + ] +} diff --git a/internal/endtoend/testdata/golang_invalid_sql_package/stderr.txt b/internal/endtoend/testdata/golang_invalid_sql_package/stderr.txt new file mode 100644 index 0000000000..249ae167b7 --- /dev/null +++ b/internal/endtoend/testdata/golang_invalid_sql_package/stderr.txt @@ -0,0 +1,2 @@ +# package +error generating code: invalid options: unknown SQL package: pgx/5