fix: failed to migrate

This commit is contained in:
ericprd 2025-03-07 19:29:34 +08:00
parent fdf572cd3d
commit dc1a4dafbc
51 changed files with 2146 additions and 2202 deletions

View File

@ -8,6 +8,7 @@ type Category struct {
ID string `gorm:"primaryKey" json:"id"` ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"` Name string `gorm:"not null" json:"name"`
Code string `gorm:"not null;unique" json:"code"` Code string `gorm:"not null;unique" json:"code"`
Sequence int `gorm:"default:null" json:"sequence"`
CreatedAt time.Time `gorm:"default:CURRENT_TIMESTAMP" json:"created_at"` CreatedAt time.Time `gorm:"default:CURRENT_TIMESTAMP" json:"created_at"`
UpdatedAt time.Time `gorm:"default:CURRENT_TIMESTAMP" json:"updated_at"` UpdatedAt time.Time `gorm:"default:CURRENT_TIMESTAMP" json:"updated_at"`
DeletedAt time.Time `gorm:"default:null" json:"deleted_at"` DeletedAt time.Time `gorm:"default:null" json:"deleted_at"`

4
go.mod
View File

@ -9,7 +9,7 @@ require (
go.uber.org/fx v1.23.0 go.uber.org/fx v1.23.0
golang.org/x/crypto v0.34.0 golang.org/x/crypto v0.34.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.4.7 gorm.io/driver/postgres v1.5.7
) )
require ( require (
@ -20,7 +20,7 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.2.0 // indirect github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect

59
go.sum
View File

@ -6,7 +6,6 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
@ -41,16 +40,12 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.2.0 h1:NdPpngX0Y6z6XDFKqmFQaE+bCtkqzvQIOt1wvBlAqs8= github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.2.0/go.mod h1:Ptn7zmohNsWEsdxRawMzk3gaKma2obW+NWTnKa0S4nk= github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.1.2/go.mod h1:2lpufsF5mRHO6SuZkm0fNYxM6SWHfvyFj62KwNzgels=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
@ -59,13 +54,8 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@ -81,7 +71,6 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc= github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc=
github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
@ -105,7 +94,6 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@ -114,8 +102,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw= go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw=
go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg= go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg=
@ -126,67 +112,30 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/crypto v0.34.0 h1:+/C6tk6rf/+t5DhUketUbD1aNGqiSX3j15Z6xuIDlBA= golang.org/x/crypto v0.34.0 h1:+/C6tk6rf/+t5DhUketUbD1aNGqiSX3j15Z6xuIDlBA=
golang.org/x/crypto v0.34.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/crypto v0.34.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.4.7 h1:J06jXZCNq7Pdf7LIPn8tZn9LsWjd81BRSKveKNr0ZfA= gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.4.7/go.mod h1:UJChCNLFKeBqQRE+HrkFUbKbq9idPXmTOk2u4Wok8S4= gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/gorm v1.24.2/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=

View File

@ -23,3 +23,5 @@ _testmain.go
.envrc .envrc
/.testdb /.testdb
.DS_Store

View File

@ -1,3 +1,83 @@
# 5.4.3 (August 5, 2023)
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
* Fix: pgxpool: background health check cannot overflow pool
* Fix: Check for nil in defer when sending batch (recover properly from panic)
* Fix: json scan of non-string pointer to pointer
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
* Optimization: improve text format numeric parsing (horpto)
# 5.4.2 (July 11, 2023)
* Fix: RowScanner errors are fatal to Rows
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
* Hstore text codec internal improvements (Evan Jones)
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
* Fix: Stop background reader as soon as possible.
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
# 5.4.1 (June 18, 2023)
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
# 5.4.0 (June 14, 2023)
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
* CancelRequest: don't try to read the reply (Nicola Murino)
* Fix: correctly handle bool type aliases (Wichert Akkerman)
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
* Fix: various hstore fixes and optimizations (Evan Jones)
* Fix: RowToStructByPos with embedded unexported struct
* Support different bool string representations (Lev Zakharov)
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
a callback.
* Fix: scanning a table type into a struct
* Fix: scan array of record to pointer to slice of struct
* Fix: handle null for json (Cemre Mengu)
* Batch Query callback is called even when there is an error
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
# 5.3.1 (February 27, 2023)
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
* Fix: sql.Scanner not being used in certain cases
* Add text format jsonpath support
* Fix: fake non-blocking read adaptive wait time
# 5.3.0 (February 11, 2023)
* Fix: json values work with sql.Scanner
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
* Fix: support scan into single dimensional arrays
* Fix: MaxConnLifetimeJitter setting actually jitter (Ben Weintraub)
* Fix: driver.Value representation of bytea should be []byte not string
* Fix: better handling of unregistered OIDs
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
* Fix: encode to json ignoring driver.Valuer
* Support sql.Scanner on renamed base type
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
* Fix: connect with multiple hostnames when one can't be resolved
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
* Fix: scanning json column into **string
* Multiple reductions in memory allocations
* Fake non-blocking read adapts its max wait time
* Improve CopyFrom performance and reduce memory usage
* Fix: encode []any to array
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
* Support v4 and v5 stdlib in same program
* Fix: text format array decoding with string of "NULL"
* Prefer binary format for arrays
# 5.2.0 (December 5, 2022) # 5.2.0 (December 5, 2022)
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov) * `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)

View File

@ -1,5 +1,5 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) [![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg) [![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml)
# pgx - PostgreSQL Driver and Toolkit # pgx - PostgreSQL Driver and Toolkit
@ -88,7 +88,7 @@ See CONTRIBUTING.md for setup instructions.
## Supported Go and PostgreSQL Versions ## Supported Go and PostgreSQL Versions
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.18 and higher and PostgreSQL 11 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.19 and higher and PostgreSQL 11 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
## Version Policy ## Version Policy
@ -132,13 +132,38 @@ These adapters can be used with the tracelog package.
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap) * [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog) * [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
* [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
* [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)
## 3rd Party Libraries with PGX Support ## 3rd Party Libraries with PGX Support
### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)
pgxmock is a mock library implementing pgx interfaces.
pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection.
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
Library for scanning data from a database into Go structs and more. Library for scanning data from a database into Go structs and more.
### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)
A carefully designed SQL client for making using SQL easier,
more productive, and less error-prone on Golang.
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
Adds GSSAPI / Kerberos authentication support. Adds GSSAPI / Kerberos authentication support.
### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)
Explicit data mapping and scanning library for Go structs and slices.
### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)
Type safe and flexible package for scanning database data into Go types.
Supports, structs, maps, slices and custom mapping functions.
### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
Code first migration library for native pgx (no database/sql abstraction).

View File

@ -21,13 +21,10 @@ type batchItemFunc func(br BatchResults) error
// Query sets fn to be called when the response to qq is received. // Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) { func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
qq.fn = func(br BatchResults) error { qq.fn = func(br BatchResults) error {
rows, err := br.Query() rows, _ := br.Query()
if err != nil {
return err
}
defer rows.Close() defer rows.Close()
err = fn(rows) err := fn(rows)
if err != nil { if err != nil {
return err return err
} }
@ -142,7 +139,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
} }
commandTag, err := br.mrr.ResultReader().Close() commandTag, err := br.mrr.ResultReader().Close()
br.err = err if err != nil {
br.err = err
br.mrr.Close()
}
if br.conn.batchTracer != nil { if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
@ -228,7 +228,7 @@ func (br *batchResults) Close() error {
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil { if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br) err := br.b.queuedQueries[br.qqIdx].fn(br)
if err != nil && br.err == nil { if err != nil {
br.err = err br.err = err
} }
} else { } else {
@ -290,7 +290,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
results, err := br.pipeline.GetResults() results, err := br.pipeline.GetResults()
if err != nil { if err != nil {
br.err = err br.err = err
return pgconn.CommandTag{}, err return pgconn.CommandTag{}, br.err
} }
var commandTag pgconn.CommandTag var commandTag pgconn.CommandTag
switch results := results.(type) { switch results := results.(type) {
@ -309,7 +309,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
}) })
} }
return commandTag, err return commandTag, br.err
} }
// Query reads the results from the next query in the batch as if the query has been sent with Query. // Query reads the results from the next query in the batch as if the query has been sent with Query.
@ -384,24 +384,20 @@ func (br *pipelineBatchResults) Close() error {
} }
}() }()
if br.err != nil { if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
return br.err
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err br.err = br.lastRows.err
return br.err return br.err
} }
if br.closed { if br.closed {
return nil return br.err
} }
// Read and run fn for all remaining items // Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil { if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br) err := br.b.queuedQueries[br.qqIdx].fn(br)
if err != nil && br.err == nil { if err != nil {
br.err = err br.err = err
} }
} else { } else {

View File

@ -178,7 +178,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "simple_protocol": case "simple_protocol":
defaultQueryExecMode = QueryExecModeSimpleProtocol defaultQueryExecMode = QueryExecModeSimpleProtocol
default: default:
return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err) return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s)
} }
} }
@ -194,20 +194,20 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
return connConfig, nil return connConfig, nil
} }
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig // ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig]
// does. In addition, it accepts the following options: // does. In addition, it accepts the following options:
// //
// default_query_exec_mode // - default_query_exec_mode.
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See // Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement". // QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
// //
// statement_cache_capacity // - statement_cache_capacity.
// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode. // The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
// Default: 512. // Default: 512.
// //
// description_cache_capacity // - description_cache_capacity.
// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode. // The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
// Default: 512. // Default: 512.
func ParseConfig(connString string) (*ConnConfig, error) { func ParseConfig(connString string) (*ConnConfig, error) {
return ParseConfigWithOptions(connString, ParseConfigOptions{}) return ParseConfigWithOptions(connString, ParseConfigOptions{})
} }
@ -382,11 +382,9 @@ func quoteIdentifier(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
} }
// Ping executes an empty sql statement against the *Conn // Ping delegates to the underlying *pgconn.PgConn.Ping.
// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
func (c *Conn) Ping(ctx context.Context) error { func (c *Conn) Ping(ctx context.Context) error {
_, err := c.Exec(ctx, ";") return c.pgConn.Ping(ctx)
return err
} }
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
@ -509,7 +507,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []a
mrr := c.pgConn.Exec(ctx, sql) mrr := c.pgConn.Exec(ctx, sql)
for mrr.NextResult() { for mrr.NextResult() {
commandTag, err = mrr.ResultReader().Close() commandTag, _ = mrr.ResultReader().Close()
} }
err = mrr.Close() err = mrr.Close()
return commandTag, err return commandTag, err
@ -585,8 +583,10 @@ const (
QueryExecModeCacheDescribe QueryExecModeCacheDescribe
// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips // Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
// to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
// when the the database schema is modified concurrently. // statement description on the first round trip and then uses it to execute the query on the second round trip. This
// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
// even when the the database schema is modified concurrently.
QueryExecModeDescribeExec QueryExecModeDescribeExec
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
@ -648,6 +648,9 @@ type QueryRewriter interface {
// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It // returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It
// is allowed to ignore the error returned from Query and handle it in Rows. // is allowed to ignore the error returned from Query and handle it in Rows.
// //
// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not
// return an error.
//
// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be
// collected before processing rather than processed while receiving each row. This avoids the possibility of the // collected before processing rather than processed while receiving each row. This avoids the possibility of the
// application processing rows from a query that the server rejected. The CollectRows function is useful here. // application processing rows from a query that the server rejected. The CollectRows function is useful here.
@ -721,43 +724,10 @@ optionLoop:
sd, explicitPreparedStatement := c.preparedStatements[sql] sd, explicitPreparedStatement := c.preparedStatements[sql]
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
if sd == nil { if sd == nil {
switch mode { sd, err = c.getStatementDescription(ctx, mode, sql)
case QueryExecModeCacheStatement: if err != nil {
if c.statementCache == nil { rows.fatal(err)
err = errDisabledStatementCache return rows, err
rows.fatal(err)
return rows, err
}
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
err = errDisabledDescriptionCache
rows.fatal(err)
return rows, err
}
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
} }
} }
@ -827,6 +797,48 @@ optionLoop:
return rows, rows.err return rows, rows.err
} }
// getStatementDescription returns the statement description of the sql query
// according to the given mode.
//
// If the mode is one that doesn't require to know the param and result OIDs
// then nil is returned without error.
func (c *Conn) getStatementDescription(
ctx context.Context,
mode QueryExecMode,
sql string,
) (sd *pgconn.StatementDescription, err error) {
switch mode {
case QueryExecModeCacheStatement:
if c.statementCache == nil {
return nil, errDisabledStatementCache
}
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
return nil, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
return nil, errDisabledDescriptionCache
}
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
return nil, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
return c.Prepare(ctx, "", sql)
}
return sd, err
}
// QueryRow is a convenience wrapper over Query. Any error that occurs while // QueryRow is a convenience wrapper over Query. Any error that occurs while
// querying is deferred until calling Scan on the returned Row. That Row will // querying is deferred until calling Scan on the returned Row. That Row will
// error with ErrNoRows if no rows are returned. // error with ErrNoRows if no rows are returned.
@ -966,7 +978,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.statementCache == nil { if c.statementCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
} }
distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueries := []*pgconn.StatementDescription{}
@ -998,7 +1010,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.descriptionCache == nil { if c.descriptionCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
} }
distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueries := []*pgconn.StatementDescription{}
@ -1052,7 +1064,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
pipeline := c.pgConn.StartPipeline(context.Background()) pipeline := c.pgConn.StartPipeline(context.Background())
defer func() { defer func() {
if pbr.err != nil { if pbr != nil && pbr.err != nil {
pipeline.Close() pipeline.Close()
} }
}() }()
@ -1065,18 +1077,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
err := pipeline.Sync() err := pipeline.Sync()
if err != nil { if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err} return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
} }
for _, sd := range distinctNewQueries { for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults() results, err := pipeline.GetResults()
if err != nil { if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err} return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
} }
resultSD, ok := results.(*pgconn.StatementDescription) resultSD, ok := results.(*pgconn.StatementDescription)
if !ok { if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
} }
// Fill in the previously empty / pending statement descriptions. // Fill in the previously empty / pending statement descriptions.
@ -1086,12 +1098,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
results, err := pipeline.GetResults() results, err := pipeline.GetResults()
if err != nil { if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err} return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
} }
_, ok := results.(*pgconn.PipelineSync) _, ok := results.(*pgconn.PipelineSync)
if !ok { if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
} }
} }
@ -1106,7 +1118,9 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
for _, bi := range b.queuedQueries { for _, bi := range b.queuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
if err != nil { if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err} // we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.query, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
} }
if bi.sd.Name == "" { if bi.sd.Name == "" {
@ -1118,7 +1132,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
err := pipeline.Sync() err := pipeline.Sync()
if err != nil { if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err} return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
} }
return &pipelineBatchResults{ return &pipelineBatchResults{
@ -1272,6 +1286,8 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com
rows, _ := c.Query(ctx, `select attname, atttypid rows, _ := c.Query(ctx, `select attname, atttypid
from pg_attribute from pg_attribute
where attrelid=$1 where attrelid=$1
and not attisdropped
and attnum > 0
order by attnum`, order by attnum`,
typrelid, typrelid,
) )
@ -1313,6 +1329,7 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
for _, sd := range invalidatedStatements { for _, sd := range invalidatedStatements {
pipeline.SendDeallocate(sd.Name) pipeline.SendDeallocate(sd.Name)
delete(c.preparedStatements, sd.Name)
} }
err := pipeline.Sync() err := pipeline.Sync()

View File

@ -85,6 +85,7 @@ type copyFrom struct {
columnNames []string columnNames []string
rowSrc CopyFromSource rowSrc CopyFromSource
readerErrChan chan error readerErrChan chan error
mode QueryExecMode
} }
func (ct *copyFrom) run(ctx context.Context) (int64, error) { func (ct *copyFrom) run(ctx context.Context) (int64, error) {
@ -105,9 +106,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
} }
quotedColumnNames := cbuf.String() quotedColumnNames := cbuf.String()
sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) var sd *pgconn.StatementDescription
if err != nil { switch ct.mode {
return 0, err case QueryExecModeExec, QueryExecModeSimpleProtocol:
// These modes don't support the binary format. Before the inclusion of the
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
// the OIDs. These prepared statements were not cached.
//
// Since that's the same behavior provided by QueryExecModeDescribeExec,
// we'll default to that mode.
ct.mode = QueryExecModeDescribeExec
fallthrough
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
var err error
sd, err = ct.conn.getStatementDescription(
ctx,
ct.mode,
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
)
if err != nil {
return 0, fmt.Errorf("statement description failed: %w", err)
}
default:
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
} }
r, w := io.Pipe() r, w := io.Pipe()
@ -167,8 +188,13 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
} }
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
const sendBufSize = 65536 - 5 // The packet has a 5-byte header
lastBufLen := 0
largestRowLen := 0
for ct.rowSrc.Next() { for ct.rowSrc.Next() {
lastBufLen = len(buf)
values, err := ct.rowSrc.Values() values, err := ct.rowSrc.Values()
if err != nil { if err != nil {
return false, nil, err return false, nil, err
@ -185,7 +211,15 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
} }
} }
if len(buf) > 65536 { rowLen := len(buf) - lastBufLen
if rowLen > largestRowLen {
largestRowLen = rowLen
}
// Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of
// io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531
// 13, 65531, 13, 65531, 13.
if len(buf) > sendBufSize-largestRowLen {
return true, buf, nil return true, buf, nil
} }
} }
@ -208,6 +242,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [
columnNames: columnNames, columnNames: columnNames,
rowSrc: rowSrc, rowSrc: rowSrc,
readerErrChan: make(chan error), readerErrChan: make(chan error),
mode: c.config.DefaultQueryExecMode,
} }
return ct.run(ctx) return ct.run(ctx)

View File

@ -7,17 +7,17 @@ details.
Establishing a Connection Establishing a Connection
The primary way of establishing a connection is with `pgx.Connect`. The primary way of establishing a connection is with [pgx.Connect]:
conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with
`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string. [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string.
Connection Pool Connection Pool
`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package [*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool. github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.
Query Interface Query Interface
@ -69,8 +69,9 @@ Use Exec to execute a query that does not return a result set.
PostgreSQL Data Types PostgreSQL Data Types
The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values pgx uses the pgtype package to converting Go values to and from PostgreSQL values. It supports many PostgreSQL types
including array and composite types. See that package's documentation for details. directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
require type registration. See that package's documentation for details.
Transactions Transactions

View File

@ -1,6 +1,7 @@
package pgx package pgx
import ( import (
"database/sql/driver"
"fmt" "fmt"
"github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/anynil"
@ -181,6 +182,19 @@ func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map,
} }
} }
} }
if !ok {
var dv driver.Valuer
if dv, ok = arg.(driver.Valuer); ok {
v, err := dv.Value()
if err != nil {
return err
}
dt, ok = m.TypeForValue(v)
if ok {
arg = v
}
}
}
if !ok { if !ok {
var str fmt.Stringer var str fmt.Stringer
if str, ok = arg.(fmt.Stringer); ok { if str, ok = arg.(fmt.Stringer); ok {

View File

@ -1,4 +1,7 @@
// Package iobufpool implements a global segregated-fit pool of buffers for IO. // Package iobufpool implements a global segregated-fit pool of buffers for IO.
//
// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
package iobufpool package iobufpool
import "sync" import "sync"
@ -10,17 +13,27 @@ var pools [18]*sync.Pool
func init() { func init() {
for i := range pools { for i := range pools {
bufLen := 1 << (minPoolExpOf2 + i) bufLen := 1 << (minPoolExpOf2 + i)
pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }} pools[i] = &sync.Pool{
New: func() any {
buf := make([]byte, bufLen)
return &buf
},
}
} }
} }
// Get gets a []byte of len size with cap <= size*2. // Get gets a []byte of len size with cap <= size*2.
func Get(size int) []byte { func Get(size int) *[]byte {
i := getPoolIdx(size) i := getPoolIdx(size)
if i >= len(pools) { if i >= len(pools) {
return make([]byte, size) buf := make([]byte, size)
return &buf
} }
return pools[i].Get().([]byte)[:size]
ptrBuf := (pools[i].Get().(*[]byte))
*ptrBuf = (*ptrBuf)[:size]
return ptrBuf
} }
func getPoolIdx(size int) int { func getPoolIdx(size int) int {
@ -36,8 +49,8 @@ func getPoolIdx(size int) int {
} }
// Put returns buf to the pool. // Put returns buf to the pool.
func Put(buf []byte) { func Put(buf *[]byte) {
i := putPoolIdx(cap(buf)) i := putPoolIdx(cap(*buf))
if i < 0 { if i < 0 {
return return
} }

View File

@ -1,70 +0,0 @@
package nbconn
import (
"sync"
)
const minBufferQueueLen = 8
type bufferQueue struct {
lock sync.Mutex
queue [][]byte
r, w int
}
func (bq *bufferQueue) pushBack(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
bq.queue[bq.w] = buf
bq.w++
}
func (bq *bufferQueue) pushFront(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
bq.queue[bq.r] = buf
bq.w++
}
func (bq *bufferQueue) popFront() []byte {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.r == bq.w {
return nil
}
buf := bq.queue[bq.r]
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
bq.r++
if bq.r == bq.w {
bq.r = 0
bq.w = 0
if len(bq.queue) > minBufferQueueLen {
bq.queue = make([][]byte, minBufferQueueLen)
}
}
return buf
}
func (bq *bufferQueue) growQueue() {
desiredLen := (len(bq.queue) + 1) * 3 / 2
if desiredLen < minBufferQueueLen {
desiredLen = minBufferQueueLen
}
newQueue := make([][]byte, desiredLen)
copy(newQueue, bq.queue)
bq.queue = newQueue
}

View File

@ -1,478 +0,0 @@
// Package nbconn implements a non-blocking net.Conn wrapper.
//
// It is designed to solve three problems.
//
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
//
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
package nbconn
import (
"crypto/tls"
"errors"
"net"
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
var errClosed = errors.New("closed")
var ErrWouldBlock = new(wouldBlockError)
const fakeNonblockingWaitDuration = 100 * time.Millisecond
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
// mode.
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
// ignore all future calls.
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
type wouldBlockError struct{}
func (*wouldBlockError) Error() string {
return "would block"
}
func (*wouldBlockError) Timeout() bool { return true }
func (*wouldBlockError) Temporary() bool { return true }
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
// the underlying connection.
type Conn interface {
net.Conn
// Flush flushes any buffered writes.
Flush() error
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
BufferReadUntilBlock() error
}
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
type NetConn struct {
// 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit
// architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288 and
// https://github.com/jackc/pgx/issues/1307. Only access with atomics
closed int64 // 0 = not closed, 1 = closed
conn net.Conn
rawConn syscall.RawConn
readQueue bufferQueue
writeQueue bufferQueue
readFlushLock sync.Mutex
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
nonblockWriteBuf []byte
nonblockWriteErr error
nonblockWriteN int
readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool
writeDeadlineLock sync.Mutex
writeDeadline time.Time
}
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
nc := &NetConn{
conn: conn,
}
if !fakeNonBlockingIO {
if sc, ok := conn.(syscall.Conn); ok {
if rawConn, err := sc.SyscallConn(); err == nil {
nc.rawConn = rawConn
}
}
}
return nc
}
// Read implements io.Reader.
func (c *NetConn) Read(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return 0, err
}
for n < len(b) {
buf := c.readQueue.popFront()
if buf == nil {
break
}
copiedN := copy(b[n:], buf)
if copiedN < len(buf) {
buf = buf[copiedN:]
c.readQueue.pushFront(buf)
} else {
iobufpool.Put(buf)
}
n += copiedN
}
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
if n > 0 {
return n, nil
}
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()
var readN int
if readNonblocking {
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.conn.Read(b[n:])
}
n += readN
return n, err
}
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
// closed. Call Flush to actually write to the underlying connection.
func (c *NetConn) Write(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
buf := iobufpool.Get(len(b))
copy(buf, b)
c.writeQueue.pushBack(buf)
return len(b), nil
}
func (c *NetConn) Close() (err error) {
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
if !swapped {
return errClosed
}
defer func() {
closeErr := c.conn.Close()
if err == nil {
err = closeErr
}
}()
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return err
}
return nil
}
func (c *NetConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *NetConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
func (c *NetConn) SetDeadline(t time.Time) error {
err := c.SetReadDeadline(t)
if err != nil {
return err
}
return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *NetConn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
if c.readDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.readDeadline = t
return nil
}
if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}
c.readDeadline = t
return c.conn.SetReadDeadline(t)
}
func (c *NetConn) SetWriteDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
if c.writeDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.writeDeadline = t
return nil
}
c.writeDeadline = t
return c.conn.SetWriteDeadline(t)
}
func (c *NetConn) Flush() error {
if c.isClosed() {
return errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
return c.flush()
}
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
func (c *NetConn) flush() error {
var stopChan chan struct{}
var errChan chan error
defer func() {
if stopChan != nil {
select {
case stopChan <- struct{}{}:
case <-errChan:
}
}
}()
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
remainingBuf := buf
for len(remainingBuf) > 0 {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
buf = buf[:len(remainingBuf)]
copy(buf, remainingBuf)
c.writeQueue.pushFront(buf)
return err
}
// Writing was blocked. Reading might unblock it.
if stopChan == nil {
stopChan, errChan = c.bufferNonblockingRead()
}
select {
case err := <-errChan:
stopChan = nil
return err
default:
}
}
}
iobufpool.Put(buf)
}
return nil
}
func (c *NetConn) BufferReadUntilBlock() error {
for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(buf)
if n > 0 {
buf = buf[:n]
c.readQueue.pushBack(buf)
}
if err != nil {
if errors.Is(err, ErrWouldBlock) {
return nil
} else {
return err
}
}
}
}
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
stopChan = make(chan struct{})
errChan = make(chan error, 1)
go func() {
for {
err := c.BufferReadUntilBlock()
if err != nil {
errChan <- err
return
}
select {
case <-stopChan:
return
default:
}
}
}()
return stopChan, errChan
}
func (c *NetConn) isClosed() bool {
closed := atomic.LoadInt64(&c.closed)
return closed == 1
}
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingWrite(b)
} else {
return c.realNonblockingWrite(b)
}
}
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
err = c.conn.SetWriteDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetWriteDeadline(c.writeDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Write(b)
}
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingRead(b)
} else {
return c.realNonblockingRead(b)
}
}
// fakeNonblockingRead emulates a non-blocking read by setting a very short
// read deadline before reading. A read that times out is reported as
// ErrWouldBlock so callers treat it as retryable.
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
	c.readDeadlineLock.Lock()
	defer c.readDeadlineLock.Unlock()
	// Only shorten the effective deadline; never extend past a deadline the
	// user has already set via SetReadDeadline.
	deadline := time.Now().Add(fakeNonblockingWaitDuration)
	if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
		err = c.conn.SetReadDeadline(deadline)
		if err != nil {
			return 0, err
		}
		defer func() {
			// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
			c.conn.SetReadDeadline(c.readDeadline)
			// err here is the named return value set by the Read below:
			// translate a deadline timeout into the ErrWouldBlock sentinel.
			if err != nil {
				if errors.Is(err, os.ErrDeadlineExceeded) {
					err = ErrWouldBlock
				}
			}
		}()
	}
	return c.conn.Read(b)
}
// syscall.Conn is interface
// TLSClient establishes a TLS connection as a client over conn using config.
//
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
// *TLSConn is returned.
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
	client := tls.Client(conn, config)

	// Perform the handshake eagerly so a later Read does not trigger writes.
	if err := client.Handshake(); err != nil {
		return nil, err
	}

	// Ensure last written part of Handshake is actually sent.
	if err := conn.Flush(); err != nil {
		return nil, err
	}

	return &TLSConn{tlsConn: client, nbConn: conn}, nil
}
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
// tls.Conn.
type TLSConn struct {
	tlsConn *tls.Conn // TLS session layered on top of nbConn
	nbConn  *NetConn  // underlying non-blocking connection
}
// Read and Write delegate to the wrapped tls.Conn; the remaining methods
// delegate to whichever layer owns the behavior.
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
// BufferReadUntilBlock and Flush operate on the underlying *NetConn buffers.
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
// Close closes the TLS session and the underlying connection.
func (tc *TLSConn) Close() error {
	// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
	// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
	// own 5 second deadline then make all set deadlines no-op.
	flushDeadline := time.Now().Add(5 * time.Second)
	tc.tlsConn.SetDeadline(flushDeadline)
	tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)

	return tc.tlsConn.Close()
}
// Deadline setters delegate directly to the wrapped tls.Conn.
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }

View File

@ -1,13 +0,0 @@
//go:build !(aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris)
package nbconn
// Not using unix build tag for support on Go 1.18.
// realNonblockingWrite falls back to the deadline-based emulation on
// platforms without raw file-descriptor non-blocking I/O support.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
	return c.fakeNonblockingWrite(b)
}
// realNonblockingRead falls back to the deadline-based emulation on
// platforms without raw file-descriptor non-blocking I/O support.
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
	return c.fakeNonblockingRead(b)
}

View File

@ -1,70 +0,0 @@
//go:build aix || android || darwin || dragonfly || freebsd || hurd || illumos || ios || linux || netbsd || openbsd || solaris
package nbconn
// Not using unix build tag for support on Go 1.18.
import (
"errors"
"io"
"syscall"
)
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
//
// The callback passed to rawConn.Write returns true so exactly one write is
// attempted without waiting for the fd to become writable; EWOULDBLOCK is
// translated into the ErrWouldBlock sentinel.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
	c.nonblockWriteBuf = b
	c.nonblockWriteN = 0
	c.nonblockWriteErr = nil
	err = c.rawConn.Write(func(fd uintptr) (done bool) {
		c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
		return true
	})
	n = c.nonblockWriteN
	// Drop the reference to the caller's buffer immediately so it can be
	// garbage collected. Otherwise the connection would pin b until the next
	// write overwrites nonblockWriteBuf.
	c.nonblockWriteBuf = nil
	if err == nil && c.nonblockWriteErr != nil {
		if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
			err = ErrWouldBlock
		} else {
			err = c.nonblockWriteErr
		}
	}
	if err != nil {
		// n may be -1 when an error occurs.
		if n < 0 {
			n = 0
		}

		return n, err
	}

	return n, nil
}
// realNonblockingRead does a single non-blocking read from the raw file
// descriptor. EWOULDBLOCK is translated into the ErrWouldBlock sentinel and a
// successful zero-byte read is reported as io.EOF.
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
	var funcErr error
	// Returning true from the callback makes rawConn.Read attempt the read
	// exactly once instead of waiting for the fd to become readable.
	err = c.rawConn.Read(func(fd uintptr) (done bool) {
		n, funcErr = syscall.Read(int(fd), b)
		return true
	})
	if err == nil && funcErr != nil {
		if errors.Is(funcErr, syscall.EWOULDBLOCK) {
			err = ErrWouldBlock
		} else {
			err = funcErr
		}
	}
	if err != nil {
		// n may be -1 when an error occurs.
		if n < 0 {
			n = 0
		}

		return n, err
	}
	// syscall read did not return an error and 0 bytes were read means EOF.
	if n == 0 {
		return 0, io.EOF
	}
	return n, nil
}

View File

@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: sc.clientFirstMessage(), Data: sc.clientFirstMessage(),
} }
c.frontend.Send(saslInitialResponse) c.frontend.Send(saslInitialResponse)
err = c.frontend.Flush() err = c.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
return err return err
} }
@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
Data: []byte(sc.clientFinalMessage()), Data: []byte(sc.clientFinalMessage()),
} }
c.frontend.Send(saslResponse) c.frontend.Send(saslResponse)
err = c.frontend.Flush() err = c.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
return err return err
} }

View File

@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"net" "net"
"net/url" "net/url"
@ -27,7 +26,7 @@ type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
type GetSSLPasswordFunc func(ctx context.Context) string type GetSSLPasswordFunc func(ctx context.Context) string
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic. // manually initialized Config will cause ConnectConfig to panic.
type Config struct { type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
@ -211,9 +210,9 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// //
// In addition, ParseConfig accepts the following options: // In addition, ParseConfig accepts the following options:
// //
// servicefile // - servicefile.
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string. // part of the connection string.
func ParseConfig(connString string) (*Config, error) { func ParseConfig(connString string) (*Config, error) {
var parseConfigOptions ParseConfigOptions var parseConfigOptions ParseConfigOptions
return ParseConfigWithOptions(connString, parseConfigOptions) return ParseConfigWithOptions(connString, parseConfigOptions)
@ -687,7 +686,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
caPath := sslrootcert caPath := sslrootcert
caCert, err := ioutil.ReadFile(caPath) caCert, err := os.ReadFile(caPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err) return nil, fmt.Errorf("unable to read CA file: %w", err)
} }
@ -705,7 +704,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
} }
if sslcert != "" && sslkey != "" { if sslcert != "" && sslkey != "" {
buf, err := ioutil.ReadFile(sslkey) buf, err := os.ReadFile(sslkey)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err) return nil, fmt.Errorf("unable to read sslkey: %w", err)
} }
@ -744,7 +743,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
} else { } else {
pemKey = pem.EncodeToMemory(block) pemKey = pem.EncodeToMemory(block)
} }
certfile, err := ioutil.ReadFile(sslcert) certfile, err := os.ReadFile(sslcert)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err) return nil, fmt.Errorf("unable to read cert: %w", err)
} }

View File

@ -0,0 +1,139 @@
// Package bgreader provides a io.Reader that can optionally buffer reads in the background.
package bgreader
import (
"io"
"sync"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
// Status values reported by BGReader.Status describing the background
// goroutine's lifecycle.
const (
	StatusStopped = iota // no background reading in progress
	StatusRunning        // background goroutine is buffering reads
	StatusStopping       // stop requested; stops after the in-flight read
)
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
	r io.Reader // underlying reader

	cond        *sync.Cond   // guards status and readResults; broadcast when new results arrive
	status      int32        // one of the Status* constants
	readResults []readResult // buffered reads not yet consumed by Read, in order
}
// readResult is one buffered background read: the bytes received and the
// error (if any) returned by the same Read call.
type readResult struct {
	buf *[]byte
	err error
}
// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()

	switch r.status {
	case StatusStopped:
		r.status = StatusRunning
		go r.bgRead()
	case StatusStopping:
		// Cancel the pending stop; the existing goroutine keeps running.
		r.status = StatusRunning
	case StatusRunning:
		// already running: nothing to do
	}
}
// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()

	// Only a running reader needs to transition; StatusStopped and
	// StatusStopping are both no-ops.
	if r.status == StatusRunning {
		r.status = StatusStopping
	}
}
// Status returns the current status of the background reader.
func (r *BGReader) Status() int32 {
	r.cond.L.Lock()
	status := r.status
	r.cond.L.Unlock()
	return status
}
// bgRead is the background goroutine body: it repeatedly reads from the
// underlying reader into pooled buffers, appends each result to readResults,
// and broadcasts so waiting Read callers wake up. It exits once a stop was
// requested or the underlying reader returned an error.
func (r *BGReader) bgRead() {
	for {
		buf := iobufpool.Get(8192)
		n, err := r.r.Read(*buf)
		*buf = (*buf)[:n]

		r.cond.L.Lock()
		r.readResults = append(r.readResults, readResult{buf: buf, err: err})
		done := r.status == StatusStopping || err != nil
		if done {
			r.status = StatusStopped
		}
		r.cond.L.Unlock()

		// Broadcast outside the lock, after the result is visible.
		r.cond.Broadcast()

		if done {
			return
		}
	}
}
// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()

	if len(r.readResults) > 0 {
		return r.readFromReadResults(p)
	}

	// There are no unread background read results and the background reader is stopped.
	// NOTE(review): the direct read below happens while holding the lock;
	// this appears safe because status is StatusStopped so no background
	// goroutine is reading concurrently — confirm against bgRead.
	if r.status == StatusStopped {
		return r.r.Read(p)
	}

	// Wait for results from the background reader
	for len(r.readResults) == 0 {
		r.cond.Wait()
	}

	return r.readFromReadResults(p)
}
// readFromReadResults copies a previously buffered background read into p.
// r.cond.L must be held. The buffered error is only surfaced once its buffer
// has been fully consumed; a partial copy leaves the remainder queued.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
	head := r.readResults[0]
	n := copy(p, *head.buf)

	if n < len(*head.buf) {
		// Partial consume: advance the buffer in place and keep it queued.
		*head.buf = (*head.buf)[n:]
		r.readResults[0].buf = head.buf
		return n, nil
	}

	// Fully consumed: return the buffer to the pool and dequeue the result.
	iobufpool.Put(head.buf)
	if len(r.readResults) == 1 {
		r.readResults = nil
	} else {
		r.readResults = r.readResults[1:]
	}
	return n, head.err
}
// New returns a BGReader wrapping r. The background reader is initially
// stopped; call Start to begin buffering.
func New(r io.Reader) *BGReader {
	return &BGReader{
		r:    r,
		cond: sync.NewCond(&sync.Mutex{}),
	}
}

View File

@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
Data: nextData, Data: nextData,
} }
c.frontend.Send(gssResponse) c.frontend.Send(gssResponse)
err = c.frontend.Flush() err = c.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
return err return err
} }

View File

@ -13,11 +13,12 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/iobufpool"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgproto3"
) )
@ -65,17 +66,24 @@ type NotificationHandler func(*PgConn, *Notification)
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct { type PgConn struct {
conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection conn net.Conn
pid uint32 // backend pid pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte txStatus byte
frontend *pgproto3.Frontend frontend *pgproto3.Frontend
bgReader *bgreader.BGReader
slowWriteTimer *time.Timer
config *Config config *Config
status byte // One of connStatus* constants status byte // One of connStatus* constants
bufferingReceive bool
bufferingReceiveMux sync.Mutex
bufferingReceiveMsg pgproto3.BackendMessage
bufferingReceiveErr error
peekedMsg pgproto3.BackendMessage peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources // Reusable / preallocated resources
@ -89,7 +97,7 @@ type PgConn struct {
} }
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. // to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt.
func Connect(ctx context.Context, connString string) (*PgConn, error) { func Connect(ctx context.Context, connString string) (*PgConn, error) {
config, err := ParseConfig(connString) config, err := ParseConfig(connString)
if err != nil { if err != nil {
@ -100,7 +108,7 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
} }
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
// and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be // and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be
// used to cancel a connect attempt. // used to cancel a connect attempt.
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
config, err := ParseConfigWithOptions(connString, parseConfigOptions) config, err := ParseConfigWithOptions(connString, parseConfigOptions)
@ -112,7 +120,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio
} }
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
// ParseConfig. ctx can be used to cancel a connect attempt. // [ParseConfig]. ctx can be used to cancel a connect attempt.
// //
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq: // authentication error will terminate the chain of attempts (like libpq:
@ -146,12 +154,15 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer := false foundBestServer := false
var fallbackConfig *FallbackConfig var fallbackConfig *FallbackConfig
for _, fc := range fallbackConfigs { for i, fc := range fallbackConfigs {
// ConnectTimeout restricts the whole connection process. // ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 { if config.ConnectTimeout != 0 {
var cancel context.CancelFunc // create new context first time or when previous host was different
ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) {
defer cancel() var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
defer cancel()
}
} else { } else {
ctx = octx ctx = octx
} }
@ -166,7 +177,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
if pgerr.Code == ERRCODE_INVALID_PASSWORD || if pgerr.Code == ERRCODE_INVALID_PASSWORD ||
pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil ||
pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || pgerr.Code == ERRCODE_INVALID_CATALOG_NAME ||
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break break
@ -203,6 +214,8 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
var configs []*FallbackConfig var configs []*FallbackConfig
var lookupErrors []error
for _, fb := range fallbacks { for _, fb := range fallbacks {
// skip resolve for unix sockets // skip resolve for unix sockets
if isAbsolutePath(fb.Host) { if isAbsolutePath(fb.Host) {
@ -217,7 +230,8 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
ips, err := lookupFn(ctx, fb.Host) ips, err := lookupFn(ctx, fb.Host)
if err != nil { if err != nil {
return nil, err lookupErrors = append(lookupErrors, err)
continue
} }
for _, ip := range ips { for _, ip := range ips {
@ -242,11 +256,18 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
} }
} }
// See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all
// errors are reported.
if len(configs) == 0 && len(lookupErrors) > 0 {
return nil, lookupErrors[0]
}
return configs, nil return configs, nil
} }
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig,
ignoreNotPreferredErr bool) (*PgConn, error) { ignoreNotPreferredErr bool,
) (*PgConn, error) {
pgConn := new(PgConn) pgConn := new(PgConn)
pgConn.config = config pgConn.config = config
pgConn.cleanupDone = make(chan struct{}) pgConn.cleanupDone = make(chan struct{})
@ -257,14 +278,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err != nil { if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
} }
nbNetConn := nbconn.NewNetConn(netConn, false)
pgConn.conn = nbNetConn pgConn.conn = netConn
pgConn.contextWatcher = newContextWatcher(nbNetConn) pgConn.contextWatcher = newContextWatcher(netConn)
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
if fallbackConfig.TLSConfig != nil { if fallbackConfig.TLSConfig != nil {
nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil { if err != nil {
netConn.Close() netConn.Close()
@ -280,7 +300,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.parameterStatuses = make(map[string]string) pgConn.parameterStatuses = make(map[string]string)
pgConn.status = connStatusConnecting pgConn.status = connStatusConnecting
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
pgConn.slowWriteTimer.Stop()
pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
startupMsg := pgproto3.StartupMessage{ startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber, ProtocolVersion: pgproto3.ProtocolVersionNumber,
@ -298,9 +321,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
} }
pgConn.frontend.Send(&startupMsg) pgConn.frontend.Send(&startupMsg)
if err := pgConn.frontend.Flush(); err != nil { if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: err} return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
} }
for { for {
@ -383,7 +406,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
) )
} }
func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
if err != nil { if err != nil {
return nil, err return nil, err
@ -398,17 +421,12 @@ func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, err
return nil, errors.New("server refused TLS connection") return nil, errors.New("server refused TLS connection")
} }
tlsConn, err := nbconn.TLSClient(conn, tlsConfig) return tls.Client(conn, tlsConfig), nil
if err != nil {
return nil, err
}
return tlsConn, nil
} }
func (pgConn *PgConn) txPasswordMessage(password string) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
return pgConn.frontend.Flush() return pgConn.flushWithPotentialWriteReadDeadlock()
} }
func hexMD5(s string) string { func hexMD5(s string) string {
@ -417,6 +435,24 @@ func hexMD5(s string) string {
return hex.EncodeToString(hash.Sum(nil)) return hex.EncodeToString(hash.Sum(nil))
} }
func (pgConn *PgConn) signalMessage() chan struct{} {
if pgConn.bufferingReceive {
panic("BUG: signalMessage when already in progress")
}
pgConn.bufferingReceive = true
pgConn.bufferingReceiveMux.Lock()
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
return ch
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@ -445,7 +481,8 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
err = &pgconnError{ err = &pgconnError{
msg: "receive message failed", msg: "receive message failed",
err: normalizeTimeoutError(ctx, err), err: normalizeTimeoutError(ctx, err),
safeToRetry: true} safeToRetry: true,
}
} }
return msg, err return msg, err
} }
@ -456,13 +493,25 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
return pgConn.peekedMsg, nil return pgConn.peekedMsg, nil
} }
msg, err := pgConn.frontend.Receive() var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
pgConn.bufferingReceiveMux.Lock()
msg = pgConn.bufferingReceiveMsg
err = pgConn.bufferingReceiveErr
pgConn.bufferingReceiveMux.Unlock()
pgConn.bufferingReceive = false
// If a timeout error happened in the background try the read again.
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.frontend.Receive()
}
if err != nil { if err != nil {
if errors.Is(err, nbconn.ErrWouldBlock) {
return nil, err
}
// Close on anything other than timeout error - everything else is fatal // Close on anything other than timeout error - everything else is fatal
var netErr net.Error var netErr net.Error
isNetErr := errors.As(err, &netErr) isNetErr := errors.As(err, &netErr)
@ -510,7 +559,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
return msg, nil return msg, nil
} }
// Conn returns the underlying net.Conn. This rarely necessary. // Conn returns the underlying net.Conn. This rarely necessary. If the connection will be directly used for reading or
// writing then SyncConn should usually be called before Conn.
func (pgConn *PgConn) Conn() net.Conn { func (pgConn *PgConn) Conn() net.Conn {
return pgConn.conn return pgConn.conn
} }
@ -573,7 +623,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
// //
// See https://github.com/jackc/pgx/issues/637 // See https://github.com/jackc/pgx/issues/637
pgConn.frontend.Send(&pgproto3.Terminate{}) pgConn.frontend.Send(&pgproto3.Terminate{})
pgConn.frontend.Flush() pgConn.flushWithPotentialWriteReadDeadlock()
return pgConn.conn.Close() return pgConn.conn.Close()
} }
@ -600,7 +650,7 @@ func (pgConn *PgConn) asyncClose() {
pgConn.conn.SetDeadline(deadline) pgConn.conn.SetDeadline(deadline)
pgConn.frontend.Send(&pgproto3.Terminate{}) pgConn.frontend.Send(&pgproto3.Terminate{})
pgConn.frontend.Flush() pgConn.flushWithPotentialWriteReadDeadlock()
}() }()
} }
@ -775,7 +825,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
pgConn.frontend.SendSync(&pgproto3.Sync{}) pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.frontend.Flush() err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return nil, err return nil, err
@ -848,9 +898,28 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be // the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance. // specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr() serverAddr := pgConn.conn.RemoteAddr()
cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) var serverNetwork string
var serverAddress string
if serverAddr.Network() == "unix" {
// for unix sockets, RemoteAddr() calls getpeername() which returns the name the
// server passed to bind(). For Postgres, this is always a relative path "./.s.PGSQL.5432"
// so connecting to it will fail. Fall back to the config's value
serverNetwork, serverAddress = NetworkAddress(pgConn.config.Host, pgConn.config.Port)
} else {
serverNetwork, serverAddress = serverAddr.Network(), serverAddr.String()
}
cancelConn, err := pgConn.config.DialFunc(ctx, serverNetwork, serverAddress)
if err != nil { if err != nil {
return err // In case of unix sockets, RemoteAddr() returns only the file part of the path. If the
// first connect failed, try the config.
if serverAddr.Network() != "unix" {
return err
}
serverNetwork, serverAddr := NetworkAddress(pgConn.config.Host, pgConn.config.Port)
cancelConn, err = pgConn.config.DialFunc(ctx, serverNetwork, serverAddr)
if err != nil {
return err
}
} }
defer cancelConn.Close() defer cancelConn.Close()
@ -868,17 +937,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
// Postgres will process the request and close the connection
// so when don't need to read the reply
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10
_, err = cancelConn.Write(buf) _, err = cancelConn.Write(buf)
if err != nil { return err
return err
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return err
}
return nil
} }
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
@ -944,7 +1007,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
} }
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush() err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
@ -1055,7 +1118,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendExecute(&pgproto3.Execute{})
pgConn.frontend.SendSync(&pgproto3.Sync{}) pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.frontend.Flush() err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
result.concludeCommand(CommandTag{}, err) result.concludeCommand(CommandTag{}, err)
@ -1088,7 +1151,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// Send copy to command // Send copy to command
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush() err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
pgConn.unlock() pgConn.unlock()
@ -1144,85 +1207,91 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
} }
// Send copy to command // Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush() err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return CommandTag{}, err return CommandTag{}, err
} }
err = pgConn.conn.SetReadDeadline(nbconn.NonBlockingDeadline) // Send copy data
if err != nil { abortCopyChan := make(chan struct{})
pgConn.asyncClose() copyErrChan := make(chan error, 1)
return CommandTag{}, err signalMessageChan := pgConn.signalMessage()
} var wg sync.WaitGroup
nonblocking := true wg.Add(1)
defer func() {
if nonblocking { go func() {
pgConn.conn.SetReadDeadline(time.Time{}) defer wg.Done()
buf := iobufpool.Get(65536)
defer iobufpool.Put(buf)
(*buf)[0] = 'd'
for {
n, readErr := r.Read((*buf)[5:cap(*buf)])
if n > 0 {
*buf = (*buf)[0 : n+5]
pgio.SetInt32((*buf)[1:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil {
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not
// setting pgConn.status or closing pgConn.cleanupDone for the same reason.
pgConn.conn.Close()
copyErrChan <- writeErr
return
}
}
if readErr != nil {
copyErrChan <- readErr
return
}
select {
case <-abortCopyChan:
return
default:
}
} }
}() }()
buf := iobufpool.Get(65536) var pgErr error
defer iobufpool.Put(buf) var copyErr error
buf[0] = 'd' for copyErr == nil && pgErr == nil {
select {
var readErr, pgErr error case copyErr = <-copyErrChan:
for pgErr == nil { case <-signalMessageChan:
// Read chunk from r. // If pgConn.receiveMessage encounters an error it will call pgConn.asyncClose. But that is a race condition with
var n int // the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an
n, readErr = r.Read(buf[5:cap(buf)]) // error is found then forcibly close the connection without sending the Terminate message.
if err := pgConn.bufferingReceiveErr; err != nil {
// Send chunk to PostgreSQL. pgConn.status = connStatusClosed
if n > 0 { pgConn.conn.Close()
buf = buf[0 : n+5] close(pgConn.cleanupDone)
pgio.SetInt32(buf[1:], int32(n+4))
writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf)
if writeErr != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
}
// Abort loop if there was a read error.
if readErr != nil {
break
}
// Read messages until error or none available.
for pgErr == nil {
msg, err := pgConn.receiveMessage()
if err != nil {
if errors.Is(err, nbconn.ErrWouldBlock) {
break
}
pgConn.asyncClose()
return CommandTag{}, normalizeTimeoutError(ctx, err) return CommandTag{}, normalizeTimeoutError(ctx, err)
} }
msg, _ := pgConn.receiveMessage()
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
break default:
signalMessageChan = pgConn.signalMessage()
} }
} }
} }
close(abortCopyChan)
// Make sure io goroutine finishes before writing.
wg.Wait()
err = pgConn.conn.SetReadDeadline(time.Time{}) if copyErr == io.EOF || pgErr != nil {
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
}
nonblocking = false
if readErr == io.EOF || pgErr != nil {
pgConn.frontend.Send(&pgproto3.CopyDone{}) pgConn.frontend.Send(&pgproto3.CopyDone{})
} else { } else {
pgConn.frontend.Send(&pgproto3.CopyFail{Message: readErr.Error()}) pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
} }
err = pgConn.frontend.Flush() err = pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
pgConn.asyncClose() pgConn.asyncClose()
return CommandTag{}, err return CommandTag{}, err
@ -1274,7 +1343,6 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := mrr.pgConn.receiveMessage() msg, err := mrr.pgConn.receiveMessage()
if err != nil { if err != nil {
mrr.pgConn.contextWatcher.Unwatch() mrr.pgConn.contextWatcher.Unwatch()
mrr.err = normalizeTimeoutError(mrr.ctx, err) mrr.err = normalizeTimeoutError(mrr.ctx, err)
@ -1417,7 +1485,8 @@ func (rr *ResultReader) NextRow() bool {
} }
// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until
// the ResultReader is closed. // the ResultReader is closed. It may return nil (for example, if the query did not return a result set or an error was
// encountered.)
func (rr *ResultReader) FieldDescriptions() []FieldDescription { func (rr *ResultReader) FieldDescriptions() []FieldDescription {
return rr.fieldDescriptions return rr.fieldDescriptions
} }
@ -1583,6 +1652,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf) _, err := pgConn.conn.Write(batch.buf)
if err != nil { if err != nil {
multiResult.closed = true multiResult.closed = true
@ -1611,29 +1682,99 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) {
return strings.Replace(s, "'", "''", -1), nil return strings.Replace(s, "'", "''", -1), nil
} }
// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and // CheckConn checks the underlying connection without writing any bytes. This is currently implemented by doing a read
// buffering until the read would block or an error occurs. This can be used to check if the server has closed the // with a very short deadline. This can be useful because a TCP connection can be broken such that a write will appear
// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails // to succeed even though it will never actually reach the server. Reading immediately before a write will detect this
// condition. If this is done immediately before sending a query it reduces the chances a query will be sent that fails
// without the client knowing whether the server received it or not. // without the client knowing whether the server received it or not.
//
// Deprecated: CheckConn is deprecated in favor of Ping. CheckConn cannot detect all types of broken connections where
// the write would still appear to succeed. Prefer Ping unless on a high latency connection.
func (pgConn *PgConn) CheckConn() error { func (pgConn *PgConn) CheckConn() error {
err := pgConn.conn.BufferReadUntilBlock() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { defer cancel()
return err
_, err := pgConn.ReceiveMessage(ctx)
if err != nil {
if !Timeout(err) {
return err
}
} }
return nil return nil
} }
// Ping pings the server. This can be useful because a TCP connection can be broken such that a write will appear to
// succeed even though it will never actually reach the server. Pinging immediately before sending a query reduces the
// chances a query will be sent that fails without the client knowing whether the server received it or not.
func (pgConn *PgConn) Ping(ctx context.Context) error {
return pgConn.Exec(ctx, "-- ping").Close()
}
// makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory.
func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
return CommandTag{s: string(buf)} return CommandTag{s: string(buf)}
} }
// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously
// blocked writing to us.
func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
// The time to wait is somewhat arbitrary. A Write should only take as long as the syscall and memcpy to the OS
// outbound network buffer unless the buffer is full (which potentially is a block). It needs to be long enough for
// the normal case, but short enough not to kill performance if a block occurs.
//
// In addition, on Windows the default timer resolution is 15.6ms. So setting the timer to less than that is
// ineffective.
if pgConn.slowWriteTimer.Reset(15 * time.Millisecond) {
panic("BUG: slow write timer already active")
}
}
// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock.
func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
// The state of the timer is not relevant upon exiting the potential slow write. It may both
// fire (due to a slow write), or not fire (due to a fast write).
_ = pgConn.slowWriteTimer.Stop()
pgConn.bgReader.Stop()
}
func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
err := pgConn.frontend.Flush()
return err
}
// SyncConn prepares the underlying net.Conn for direct use. PgConn may internally buffer reads or use goroutines for
// background IO. This means that any direct use of the underlying net.Conn may be corrupted if a read is already
// buffered or a read is in progress. SyncConn drains read buffers and stops background IO. In some cases this may
// require sending a ping to the server. ctx can be used to cancel this operation. This should be called before any
// operation that will use the underlying net.Conn directly. e.g. Before Conn() or Hijack().
//
// This should not be confused with the PostgreSQL protocol Sync message.
func (pgConn *PgConn) SyncConn(ctx context.Context) error {
for i := 0; i < 10; i++ {
if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 {
return nil
}
err := pgConn.Ping(ctx)
if err != nil {
return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err)
}
}
// This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as
// LISTEN/NOTIFY or log notifications such that we never can get an empty buffer.
return errors.New("SyncConn: conn never synchronized")
}
// HijackedConn is the result of hijacking a connection. // HijackedConn is the result of hijacking a connection.
// //
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility. // compatibility.
type HijackedConn struct { type HijackedConn struct {
Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection Conn net.Conn
PID uint32 // backend pid PID uint32 // backend pid
SecretKey uint32 // key to use to send a cancel query message to the server SecretKey uint32 // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server ParameterStatuses map[string]string // parameters that have been reported by the server
@ -1642,9 +1783,9 @@ type HijackedConn struct {
Config *Config Config *Config
} }
// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. // Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately
// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the // before Hijack. pgConn is unusable after hijacking. Hijacking is typically only useful when using pgconn to establish
// raw connection after that (e.g. a load balancer or proxy). // a connection, but taking complete control of the raw connection after that (e.g. a load balancer or proxy).
// //
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility. // compatibility.
@ -1668,6 +1809,8 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of // Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of
// PgConn.Hijack. The connection must be in an idle state. // PgConn.Hijack. The connection must be in an idle state.
// //
// hc.Frontend is replaced by a new pgproto3.Frontend built by hc.Config.BuildFrontend.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility. // compatibility.
func Construct(hc *HijackedConn) (*PgConn, error) { func Construct(hc *HijackedConn) (*PgConn, error) {
@ -1686,6 +1829,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
} }
pgConn.contextWatcher = newContextWatcher(pgConn.conn) pgConn.contextWatcher = newContextWatcher(pgConn.conn)
pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
pgConn.slowWriteTimer.Stop()
pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn)
return pgConn, nil return pgConn, nil
} }
@ -1808,7 +1955,7 @@ func (p *Pipeline) Flush() error {
return errors.New("pipeline closed") return errors.New("pipeline closed")
} }
err := p.conn.frontend.Flush() err := p.conn.flushWithPotentialWriteReadDeadlock()
if err != nil { if err != nil {
err = normalizeTimeoutError(p.ctx, err) err = normalizeTimeoutError(p.ctx, err)
@ -1887,7 +2034,6 @@ func (p *Pipeline) GetResults() (results any, err error) {
} }
} }
} }
func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {

View File

@ -196,7 +196,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
case AuthTypeCleartextPassword, AuthTypeMD5Password: case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough fallthrough
default: default:
// to maintain backwards compatability // to maintain backwards compatibility
msg = &PasswordMessage{} msg = &PasswordMessage{}
} }
case 'Q': case 'Q':
@ -233,11 +233,11 @@ func (b *Backend) Receive() (FrontendMessage, error) {
// contextual identification of FrontendMessages. For example, in the // contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage: // PG message flow documentation for PasswordMessage:
// //
// Byte1('p') // Byte1('p')
// //
// Identifies the message as a password response. Note that this is also used for // Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from // GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context. // the context.
// //
// Since the Frontend does not know about the state of a backend, it is important // Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend. // to call SetAuthType() after an authentication request is received by the Frontend.

View File

@ -14,7 +14,7 @@ import (
type chunkReader struct { type chunkReader struct {
r io.Reader r io.Reader
buf []byte buf *[]byte
rp, wp int // buf read position and write position rp, wp int // buf read position and write position
minBufSize int minBufSize int
@ -45,7 +45,7 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
func (r *chunkReader) Next(n int) (buf []byte, err error) { func (r *chunkReader) Next(n int) (buf []byte, err error) {
// Reset the buffer if it is empty // Reset the buffer if it is empty
if r.rp == r.wp { if r.rp == r.wp {
if len(r.buf) != r.minBufSize { if len(*r.buf) != r.minBufSize {
iobufpool.Put(r.buf) iobufpool.Put(r.buf)
r.buf = iobufpool.Get(r.minBufSize) r.buf = iobufpool.Get(r.minBufSize)
} }
@ -55,15 +55,15 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf // n bytes already in buf
if (r.wp - r.rp) >= n { if (r.wp - r.rp) >= n {
buf = r.buf[r.rp : r.rp+n : r.rp+n] buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
r.rp += n r.rp += n
return buf, err return buf, err
} }
// buf is smaller than requested number of bytes // buf is smaller than requested number of bytes
if len(r.buf) < n { if len(*r.buf) < n {
bigBuf := iobufpool.Get(n) bigBuf := iobufpool.Get(n)
r.wp = copy(bigBuf, r.buf[r.rp:r.wp]) r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp])
r.rp = 0 r.rp = 0
iobufpool.Put(r.buf) iobufpool.Put(r.buf)
r.buf = bigBuf r.buf = bigBuf
@ -71,20 +71,20 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
// buf is large enough, but need to shift filled area to start to make enough contiguous space // buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp) minReadCount := n - (r.wp - r.rp)
if (len(r.buf) - r.wp) < minReadCount { if (len(*r.buf) - r.wp) < minReadCount {
r.wp = copy(r.buf, r.buf[r.rp:r.wp]) r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp])
r.rp = 0 r.rp = 0
} }
// Read at least the required number of bytes from the underlying io.Reader // Read at least the required number of bytes from the underlying io.Reader
readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount) readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount)
r.wp += readBytesCount r.wp += readBytesCount
// fmt.Println("read", n) // fmt.Println("read", n)
if err != nil { if err != nil {
return nil, err return nil, err
} }
buf = r.buf[r.rp : r.rp+n : r.rp+n] buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
r.rp += n r.rp += n
return buf, nil return buf, nil
} }

View File

@ -361,3 +361,7 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
func (f *Frontend) GetAuthType() uint32 { func (f *Frontend) GetAuthType() uint32 {
return f.authType return f.authType
} }
func (f *Frontend) ReadBufferLen() int {
return f.cr.wp - f.cr.rp
}

View File

@ -6,15 +6,18 @@ import (
"io" "io"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
// tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the // tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the
// format produced by the libpq C function PQtrace. // format produced by the libpq C function PQtrace.
type tracer struct { type tracer struct {
TracerOptions
mux sync.Mutex
w io.Writer w io.Writer
buf *bytes.Buffer buf *bytes.Buffer
TracerOptions
} }
// TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. // TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags.
@ -119,278 +122,255 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) {
case *Terminate: case *Terminate:
t.traceTerminate(sender, encodedLen, msg) t.traceTerminate(sender, encodedLen, msg)
default: default:
t.beginTrace(sender, encodedLen, "Unknown") t.writeTrace(sender, encodedLen, "Unknown", nil)
t.finishTrace()
} }
} }
func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) {
t.beginTrace(sender, encodedLen, "AuthenticationCleartextPassword") t.writeTrace(sender, encodedLen, "AuthenticationCleartextPassword", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) {
t.beginTrace(sender, encodedLen, "AuthenticationGSS") t.writeTrace(sender, encodedLen, "AuthenticationGSS", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) {
t.beginTrace(sender, encodedLen, "AuthenticationGSSContinue") t.writeTrace(sender, encodedLen, "AuthenticationGSSContinue", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) {
t.beginTrace(sender, encodedLen, "AuthenticationMD5Password") t.writeTrace(sender, encodedLen, "AuthenticationMD5Password", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) {
t.beginTrace(sender, encodedLen, "AuthenticationOk") t.writeTrace(sender, encodedLen, "AuthenticationOk", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) {
t.beginTrace(sender, encodedLen, "AuthenticationSASL") t.writeTrace(sender, encodedLen, "AuthenticationSASL", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) {
t.beginTrace(sender, encodedLen, "AuthenticationSASLContinue") t.writeTrace(sender, encodedLen, "AuthenticationSASLContinue", nil)
t.finishTrace()
} }
func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) {
t.beginTrace(sender, encodedLen, "AuthenticationSASLFinal") t.writeTrace(sender, encodedLen, "AuthenticationSASLFinal", nil)
t.finishTrace()
} }
func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) {
t.beginTrace(sender, encodedLen, "BackendKeyData") t.writeTrace(sender, encodedLen, "BackendKeyData", func() {
if t.RegressMode { if t.RegressMode {
t.buf.WriteString("\t NNNN NNNN") t.buf.WriteString("\t NNNN NNNN")
} else { } else {
fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey)
} }
t.finishTrace() })
} }
func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) {
t.beginTrace(sender, encodedLen, "Bind") t.writeTrace(sender, encodedLen, "Bind", func() {
fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes))
for _, fc := range msg.ParameterFormatCodes { for _, fc := range msg.ParameterFormatCodes {
fmt.Fprintf(t.buf, " %d", fc) fmt.Fprintf(t.buf, " %d", fc)
} }
fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) fmt.Fprintf(t.buf, " %d", len(msg.Parameters))
for _, p := range msg.Parameters { for _, p := range msg.Parameters {
fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p))
} }
fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes))
for _, fc := range msg.ResultFormatCodes { for _, fc := range msg.ResultFormatCodes {
fmt.Fprintf(t.buf, " %d", fc) fmt.Fprintf(t.buf, " %d", fc)
} }
t.finishTrace() })
} }
func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) {
t.beginTrace(sender, encodedLen, "BindComplete") t.writeTrace(sender, encodedLen, "BindComplete", nil)
t.finishTrace()
} }
func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) {
t.beginTrace(sender, encodedLen, "CancelRequest") t.writeTrace(sender, encodedLen, "CancelRequest", nil)
t.finishTrace()
} }
func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) {
t.beginTrace(sender, encodedLen, "Close") t.writeTrace(sender, encodedLen, "Close", nil)
t.finishTrace()
} }
func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) {
t.beginTrace(sender, encodedLen, "CloseComplete") t.writeTrace(sender, encodedLen, "CloseComplete", nil)
t.finishTrace()
} }
func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) {
t.beginTrace(sender, encodedLen, "CommandComplete") t.writeTrace(sender, encodedLen, "CommandComplete", func() {
fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag))
t.finishTrace() })
} }
func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) {
t.beginTrace(sender, encodedLen, "CopyBothResponse") t.writeTrace(sender, encodedLen, "CopyBothResponse", nil)
t.finishTrace()
} }
func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) {
t.beginTrace(sender, encodedLen, "CopyData") t.writeTrace(sender, encodedLen, "CopyData", nil)
t.finishTrace()
} }
func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) {
t.beginTrace(sender, encodedLen, "CopyDone") t.writeTrace(sender, encodedLen, "CopyDone", nil)
t.finishTrace()
} }
func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) {
t.beginTrace(sender, encodedLen, "CopyFail") t.writeTrace(sender, encodedLen, "CopyFail", func() {
fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message)))
t.finishTrace() })
} }
func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) {
t.beginTrace(sender, encodedLen, "CopyInResponse") t.writeTrace(sender, encodedLen, "CopyInResponse", nil)
t.finishTrace()
} }
func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) {
t.beginTrace(sender, encodedLen, "CopyOutResponse") t.writeTrace(sender, encodedLen, "CopyOutResponse", nil)
t.finishTrace()
} }
func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) {
t.beginTrace(sender, encodedLen, "DataRow") t.writeTrace(sender, encodedLen, "DataRow", func() {
fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) fmt.Fprintf(t.buf, "\t %d", len(msg.Values))
for _, v := range msg.Values { for _, v := range msg.Values {
if v == nil { if v == nil {
t.buf.WriteString(" -1") t.buf.WriteString(" -1")
} else { } else {
fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v))
}
} }
} })
t.finishTrace()
} }
func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) {
t.beginTrace(sender, encodedLen, "Describe") t.writeTrace(sender, encodedLen, "Describe", func() {
fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name)))
t.finishTrace() })
} }
func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) {
t.beginTrace(sender, encodedLen, "EmptyQueryResponse") t.writeTrace(sender, encodedLen, "EmptyQueryResponse", nil)
t.finishTrace()
} }
func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) {
t.beginTrace(sender, encodedLen, "ErrorResponse") t.writeTrace(sender, encodedLen, "ErrorResponse", nil)
t.finishTrace()
} }
func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) {
t.beginTrace(sender, encodedLen, "Execute") t.writeTrace(sender, encodedLen, "Execute", func() {
fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows)
t.finishTrace() })
} }
func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) {
t.beginTrace(sender, encodedLen, "Flush") t.writeTrace(sender, encodedLen, "Flush", nil)
t.finishTrace()
} }
func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) {
t.beginTrace(sender, encodedLen, "FunctionCall") t.writeTrace(sender, encodedLen, "FunctionCall", nil)
t.finishTrace()
} }
func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) {
t.beginTrace(sender, encodedLen, "FunctionCallResponse") t.writeTrace(sender, encodedLen, "FunctionCallResponse", nil)
t.finishTrace()
} }
func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) {
t.beginTrace(sender, encodedLen, "GSSEncRequest") t.writeTrace(sender, encodedLen, "GSSEncRequest", nil)
t.finishTrace()
} }
func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) {
t.beginTrace(sender, encodedLen, "NoData") t.writeTrace(sender, encodedLen, "NoData", nil)
t.finishTrace()
} }
func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) {
t.beginTrace(sender, encodedLen, "NoticeResponse") t.writeTrace(sender, encodedLen, "NoticeResponse", nil)
t.finishTrace()
} }
func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) {
t.beginTrace(sender, encodedLen, "NotificationResponse") t.writeTrace(sender, encodedLen, "NotificationResponse", func() {
fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload)))
t.finishTrace() })
} }
func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) {
t.beginTrace(sender, encodedLen, "ParameterDescription") t.writeTrace(sender, encodedLen, "ParameterDescription", nil)
t.finishTrace()
} }
func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) {
t.beginTrace(sender, encodedLen, "ParameterStatus") t.writeTrace(sender, encodedLen, "ParameterStatus", func() {
fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value)))
t.finishTrace() })
} }
func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) {
t.beginTrace(sender, encodedLen, "Parse") t.writeTrace(sender, encodedLen, "Parse", func() {
fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs))
for _, oid := range msg.ParameterOIDs { for _, oid := range msg.ParameterOIDs {
fmt.Fprintf(t.buf, " %d", oid) fmt.Fprintf(t.buf, " %d", oid)
} }
t.finishTrace() })
} }
func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) {
t.beginTrace(sender, encodedLen, "ParseComplete") t.writeTrace(sender, encodedLen, "ParseComplete", nil)
t.finishTrace()
} }
func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) {
t.beginTrace(sender, encodedLen, "PortalSuspended") t.writeTrace(sender, encodedLen, "PortalSuspended", nil)
t.finishTrace()
} }
func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) {
t.beginTrace(sender, encodedLen, "Query") t.writeTrace(sender, encodedLen, "Query", func() {
fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String)))
t.finishTrace() })
} }
func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) {
t.beginTrace(sender, encodedLen, "ReadyForQuery") t.writeTrace(sender, encodedLen, "ReadyForQuery", func() {
fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) fmt.Fprintf(t.buf, "\t %c", msg.TxStatus)
t.finishTrace() })
} }
func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) {
t.beginTrace(sender, encodedLen, "RowDescription") t.writeTrace(sender, encodedLen, "RowDescription", func() {
fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) fmt.Fprintf(t.buf, "\t %d", len(msg.Fields))
for _, fd := range msg.Fields { for _, fd := range msg.Fields {
fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format)
} }
t.finishTrace() })
} }
func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) {
t.beginTrace(sender, encodedLen, "SSLRequest") t.writeTrace(sender, encodedLen, "SSLRequest", nil)
t.finishTrace()
} }
func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) {
t.beginTrace(sender, encodedLen, "StartupMessage") t.writeTrace(sender, encodedLen, "StartupMessage", nil)
t.finishTrace()
} }
func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) {
t.beginTrace(sender, encodedLen, "Sync") t.writeTrace(sender, encodedLen, "Sync", nil)
t.finishTrace()
} }
func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) {
t.beginTrace(sender, encodedLen, "Terminate") t.writeTrace(sender, encodedLen, "Terminate", nil)
t.finishTrace()
} }
func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, writeDetails func()) {
t.mux.Lock()
defer t.mux.Unlock()
defer func() {
if t.buf.Cap() > 1024 {
t.buf = &bytes.Buffer{}
} else {
t.buf.Reset()
}
}()
if !t.SuppressTimestamps { if !t.SuppressTimestamps {
now := time.Now() now := time.Now()
t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000"))
@ -402,17 +382,13 @@ func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) {
t.buf.WriteString(msgType) t.buf.WriteString(msgType)
t.buf.WriteByte('\t') t.buf.WriteByte('\t')
t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10))
}
func (t *tracer) finishTrace() { if writeDetails != nil {
writeDetails()
}
t.buf.WriteByte('\n') t.buf.WriteByte('\n')
t.buf.WriteTo(t.w) t.buf.WriteTo(t.w)
if t.buf.Cap() > 1024 {
t.buf = &bytes.Buffer{}
} else {
t.buf.Reset()
}
} }
// traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to // traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to

View File

@ -5,7 +5,6 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"reflect"
"strconv" "strconv"
"strings" "strings"
"unicode" "unicode"
@ -363,38 +362,18 @@ func quoteArrayElement(src string) string {
} }
func isSpace(ch byte) bool { func isSpace(ch byte) bool {
// see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224 // see array_isspace:
return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f' // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c
return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f'
} }
func quoteArrayElementIfNeeded(src string) string { func quoteArrayElementIfNeeded(src string) string {
if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) {
return quoteArrayElement(src) return quoteArrayElement(src)
} }
return src return src
} }
func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) {
switch value.Kind() {
case reflect.Array:
fallthrough
case reflect.Slice:
length := value.Len()
if 0 == elementsLength {
elementsLength = length
} else {
elementsLength *= length
}
dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1})
for i := 0; i < length; i++ {
if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok {
return d, l, true
}
}
}
return dimensions, elementsLength, true
}
// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves // Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves
// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. // PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed.
type Array[T any] struct { type Array[T any] struct {

View File

@ -47,7 +47,16 @@ func (c *ArrayCodec) FormatSupported(format int16) bool {
} }
func (c *ArrayCodec) PreferredFormat() int16 { func (c *ArrayCodec) PreferredFormat() int16 {
return c.ElementType.Codec.PreferredFormat() // The binary format should always be preferred for arrays if it is supported. Usually, this will happen automatically
// because most types that support binary prefer it. However, text, json, and jsonb support binary but prefer the text
// format. This is because it is simpler for jsonb and PostgreSQL can be significantly faster using the text format
// for text-like data types than binary. However, arrays appear to always be faster in binary.
//
// https://www.postgresql.org/message-id/CAMovtNoHFod2jMAKQjjxv209PCTJx5Kc66anwWvX0mEiaXwgmA%40mail.gmail.com
if c.ElementType.Codec.FormatSupported(BinaryFormatCode) {
return BinaryFormatCode
}
return TextFormatCode
} }
func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
@ -60,7 +69,9 @@ func (c *ArrayCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enc
elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType)
if elementEncodePlan == nil { if elementEncodePlan == nil {
return nil if reflect.TypeOf(elementType) != nil {
return nil
}
} }
switch format { switch format {
@ -301,7 +312,7 @@ func (c *ArrayCodec) decodeText(m *Map, arrayOID uint32, src []byte, array Array
for i, s := range uta.Elements { for i, s := range uta.Elements {
elem := array.ScanIndex(i) elem := array.ScanIndex(i)
var elemSrc []byte var elemSrc []byte
if s != "NULL" { if s != "NULL" || uta.Quoted[i] {
elemSrc = []byte(s) elemSrc = []byte(s)
} }

View File

@ -1,10 +1,12 @@
package pgtype package pgtype
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings"
) )
type BoolScanner interface { type BoolScanner interface {
@ -264,8 +266,8 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error {
return fmt.Errorf("cannot scan NULL into %T", dst) return fmt.Errorf("cannot scan NULL into %T", dst)
} }
if len(src) != 1 { if len(src) == 0 {
return fmt.Errorf("invalid length for bool: %v", len(src)) return fmt.Errorf("cannot scan empty string into %T", dst)
} }
p, ok := (dst).(*bool) p, ok := (dst).(*bool)
@ -273,7 +275,12 @@ func (scanPlanTextAnyToBool) Scan(src []byte, dst any) error {
return ErrScanTargetTypeChanged return ErrScanTargetTypeChanged
} }
*p = src[0] == 't' v, err := planTextToBool(src)
if err != nil {
return err
}
*p = v
return nil return nil
} }
@ -309,9 +316,28 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error {
return s.ScanBool(Bool{}) return s.ScanBool(Bool{})
} }
if len(src) != 1 { if len(src) == 0 {
return fmt.Errorf("invalid length for bool: %v", len(src)) return fmt.Errorf("cannot scan empty string into %T", dst)
} }
return s.ScanBool(Bool{Bool: src[0] == 't', Valid: true}) v, err := planTextToBool(src)
if err != nil {
return err
}
return s.ScanBool(Bool{Bool: v, Valid: true})
}
// https://www.postgresql.org/docs/11/datatype-boolean.html
func planTextToBool(src []byte) (bool, error) {
s := string(bytes.ToLower(bytes.TrimSpace(src)))
switch {
case strings.HasPrefix("true", s), strings.HasPrefix("yes", s), s == "on", s == "1":
return true, nil
case strings.HasPrefix("false", s), strings.HasPrefix("no", s), strings.HasPrefix("off", s), s == "0":
return false, nil
default:
return false, fmt.Errorf("unknown boolean string representation %q", src)
}
} }

View File

@ -910,3 +910,43 @@ func (a *anyMultiDimSliceArray) ScanIndexType() any {
} }
return reflect.New(lowestSliceType.Elem()).Interface() return reflect.New(lowestSliceType.Elem()).Interface()
} }
type anyArrayArrayReflect struct {
array reflect.Value
}
func (a anyArrayArrayReflect) Dimensions() []ArrayDimension {
return []ArrayDimension{{Length: int32(a.array.Len()), LowerBound: 1}}
}
func (a anyArrayArrayReflect) Index(i int) any {
return a.array.Index(i).Interface()
}
func (a anyArrayArrayReflect) IndexType() any {
return reflect.New(a.array.Type().Elem()).Elem().Interface()
}
func (a *anyArrayArrayReflect) SetDimensions(dimensions []ArrayDimension) error {
if dimensions == nil {
return fmt.Errorf("anyArrayArrayReflect: cannot scan NULL into %v", a.array.Type().String())
}
if len(dimensions) != 1 {
return fmt.Errorf("anyArrayArrayReflect: cannot scan multi-dimensional array into %v", a.array.Type().String())
}
if int(dimensions[0].Length) != a.array.Len() {
return fmt.Errorf("anyArrayArrayReflect: cannot scan array with length %v into %v", dimensions[0].Length, a.array.Type().String())
}
return nil
}
func (a *anyArrayArrayReflect) ScanIndex(i int) any {
return a.array.Index(i).Addr().Interface()
}
func (a *anyArrayArrayReflect) ScanIndexType() any {
return reflect.New(a.array.Type().Elem()).Interface()
}

View File

@ -238,7 +238,7 @@ func decodeHexBytea(src []byte) ([]byte, error) {
} }
func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c ByteaCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return codecDecodeToTextFormat(c, m, oid, format, src) return c.DecodeValue(m, oid, format, src)
} }
func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { func (c ByteaCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {

View File

@ -1,377 +1,9 @@
package pgtype package pgtype
import ( import (
"database/sql"
"fmt"
"math"
"reflect" "reflect"
"time"
) )
const (
maxUint = ^uint(0)
maxInt = int(maxUint >> 1)
minInt = -maxInt - 1
)
// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8
func underlyingNumberType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.Int:
convVal := int(refVal.Int())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Int8:
convVal := int8(refVal.Int())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Int16:
convVal := int16(refVal.Int())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Int32:
convVal := int32(refVal.Int())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Int64:
convVal := int64(refVal.Int())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Uint:
convVal := uint(refVal.Uint())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Uint8:
convVal := uint8(refVal.Uint())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Uint16:
convVal := uint16(refVal.Uint())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Uint32:
convVal := uint32(refVal.Uint())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Uint64:
convVal := uint64(refVal.Uint())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Float32:
convVal := float32(refVal.Float())
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.Float64:
convVal := refVal.Float()
return convVal, reflect.TypeOf(convVal) != refVal.Type()
case reflect.String:
convVal := refVal.String()
return convVal, reflect.TypeOf(convVal) != refVal.Type()
}
return nil, false
}
// underlyingBoolType gets the underlying type that can be converted to Bool
func underlyingBoolType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.Bool:
convVal := refVal.Bool()
return convVal, reflect.TypeOf(convVal) != refVal.Type()
}
return nil, false
}
// underlyingBytesType gets the underlying type that can be converted to []byte
func underlyingBytesType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.Slice:
if refVal.Type().Elem().Kind() == reflect.Uint8 {
convVal := refVal.Bytes()
return convVal, reflect.TypeOf(convVal) != refVal.Type()
}
}
return nil, false
}
// underlyingStringType gets the underlying type that can be converted to String
func underlyingStringType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.String:
convVal := refVal.String()
return convVal, reflect.TypeOf(convVal) != refVal.Type()
}
return nil, false
}
// underlyingPtrType dereferences a pointer
func underlyingPtrType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
}
return nil, false
}
// underlyingTimeType gets the underlying type that can be converted to time.Time
func underlyingTimeType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
}
timeType := reflect.TypeOf(time.Time{})
if refVal.Type().ConvertibleTo(timeType) {
return refVal.Convert(timeType).Interface(), true
}
return nil, false
}
// underlyingUUIDType gets the underlying type that can be converted to [16]byte
func underlyingUUIDType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return time.Time{}, false
}
convVal := refVal.Elem().Interface()
return convVal, true
}
uuidType := reflect.TypeOf([16]byte{})
if refVal.Type().ConvertibleTo(uuidType) {
return refVal.Convert(uuidType).Interface(), true
}
return nil, false
}
// underlyingSliceType gets the underlying slice type
func underlyingSliceType(val any) (any, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.Slice:
baseSliceType := reflect.SliceOf(refVal.Type().Elem())
if refVal.Type().ConvertibleTo(baseSliceType) {
convVal := refVal.Convert(baseSliceType)
return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type()
}
}
return nil, false
}
func int64AssignTo(srcVal int64, srcValid bool, dst any) error {
if srcValid {
switch v := dst.(type) {
case *int:
if srcVal < int64(minInt) {
return fmt.Errorf("%d is less than minimum value for int", srcVal)
} else if srcVal > int64(maxInt) {
return fmt.Errorf("%d is greater than maximum value for int", srcVal)
}
*v = int(srcVal)
case *int8:
if srcVal < math.MinInt8 {
return fmt.Errorf("%d is less than minimum value for int8", srcVal)
} else if srcVal > math.MaxInt8 {
return fmt.Errorf("%d is greater than maximum value for int8", srcVal)
}
*v = int8(srcVal)
case *int16:
if srcVal < math.MinInt16 {
return fmt.Errorf("%d is less than minimum value for int16", srcVal)
} else if srcVal > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for int16", srcVal)
}
*v = int16(srcVal)
case *int32:
if srcVal < math.MinInt32 {
return fmt.Errorf("%d is less than minimum value for int32", srcVal)
} else if srcVal > math.MaxInt32 {
return fmt.Errorf("%d is greater than maximum value for int32", srcVal)
}
*v = int32(srcVal)
case *int64:
if srcVal < math.MinInt64 {
return fmt.Errorf("%d is less than minimum value for int64", srcVal)
} else if srcVal > math.MaxInt64 {
return fmt.Errorf("%d is greater than maximum value for int64", srcVal)
}
*v = int64(srcVal)
case *uint:
if srcVal < 0 {
return fmt.Errorf("%d is less than zero for uint", srcVal)
} else if uint64(srcVal) > uint64(maxUint) {
return fmt.Errorf("%d is greater than maximum value for uint", srcVal)
}
*v = uint(srcVal)
case *uint8:
if srcVal < 0 {
return fmt.Errorf("%d is less than zero for uint8", srcVal)
} else if srcVal > math.MaxUint8 {
return fmt.Errorf("%d is greater than maximum value for uint8", srcVal)
}
*v = uint8(srcVal)
case *uint16:
if srcVal < 0 {
return fmt.Errorf("%d is less than zero for uint32", srcVal)
} else if srcVal > math.MaxUint16 {
return fmt.Errorf("%d is greater than maximum value for uint16", srcVal)
}
*v = uint16(srcVal)
case *uint32:
if srcVal < 0 {
return fmt.Errorf("%d is less than zero for uint32", srcVal)
} else if srcVal > math.MaxUint32 {
return fmt.Errorf("%d is greater than maximum value for uint32", srcVal)
}
*v = uint32(srcVal)
case *uint64:
if srcVal < 0 {
return fmt.Errorf("%d is less than zero for uint64", srcVal)
}
*v = uint64(srcVal)
case sql.Scanner:
return v.Scan(srcVal)
default:
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
el := v.Elem()
switch el.Kind() {
// if dst is a pointer to pointer, strip the pointer and try again
case reflect.Ptr:
if el.IsNil() {
// allocate destination
el.Set(reflect.New(el.Type().Elem()))
}
return int64AssignTo(srcVal, srcValid, el.Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if el.OverflowInt(int64(srcVal)) {
return fmt.Errorf("cannot put %d into %T", srcVal, dst)
}
el.SetInt(int64(srcVal))
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if srcVal < 0 {
return fmt.Errorf("%d is less than zero for %T", srcVal, dst)
}
if el.OverflowUint(uint64(srcVal)) {
return fmt.Errorf("cannot put %d into %T", srcVal, dst)
}
el.SetUint(uint64(srcVal))
return nil
}
}
return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
}
return nil
}
// if dst is a pointer to pointer and srcStatus is not Valid, nil it out
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
el := v.Elem()
if el.Kind() == reflect.Ptr {
el.Set(reflect.Zero(el.Type()))
return nil
}
}
return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst)
}
func float64AssignTo(srcVal float64, srcValid bool, dst any) error {
if srcValid {
switch v := dst.(type) {
case *float32:
*v = float32(srcVal)
case *float64:
*v = srcVal
default:
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
el := v.Elem()
switch el.Kind() {
// if dst is a type alias of a float32 or 64, set dst val
case reflect.Float32, reflect.Float64:
el.SetFloat(srcVal)
return nil
// if dst is a pointer to pointer, strip the pointer and try again
case reflect.Ptr:
if el.IsNil() {
// allocate destination
el.Set(reflect.New(el.Type().Elem()))
}
return float64AssignTo(srcVal, srcValid, el.Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
i64 := int64(srcVal)
if float64(i64) == srcVal {
return int64AssignTo(i64, srcValid, dst)
}
}
}
return fmt.Errorf("cannot assign %v into %T", srcVal, dst)
}
return nil
}
// if dst is a pointer to pointer and srcStatus is not Valid, nil it out
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
el := v.Elem()
if el.Kind() == reflect.Ptr {
el.Set(reflect.Zero(el.Type()))
return nil
}
}
return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcValid, dst)
}
func NullAssignTo(dst any) error { func NullAssignTo(dst any) error {
dstPtr := reflect.ValueOf(dst) dstPtr := reflect.ValueOf(dst)

View File

@ -57,27 +57,7 @@ JSON Support
pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types.
Array Support Extending Existing PostgreSQL Type Support
ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an
ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays.
Composite Support
CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of
the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and
CompositeIndexGetter.
Enum Support
PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can
reduce memory usage.
Array, Composite, and Enum Type Registration
Array, composite, and enum types can be easily registered from a pgx.Conn with the LoadType method.
Extending Existing Type Support
Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example, Generally, all Codecs will support interfaces that can be implemented to enable scanning and encoding. For example,
PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use PointCodec can use any Go type that implements the PointScanner and PointValuer interfaces. So rather than use
@ -90,11 +70,58 @@ pgx support such as github.com/shopspring/decimal. These types can be registered
logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example logic. See https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid for a example
integrations. integrations.
Entirely New Type Support New PostgreSQL Type Support
If the PostgreSQL type is not already supported then an OID / Codec mapping can be registered with Map.RegisterType. pgtype uses the PostgreSQL OID to determine how to encode or decode a value. pgtype supports array, composite, domain,
There is no difference between a Codec defined and registered by the application and a Codec built in to pgtype. See any and enum types. However, any type created in PostgreSQL with CREATE TYPE will receive a new OID. This means that the OID
of the Codecs in pgtype for Codec examples and for examples of type registration. of each new PostgreSQL type must be registered for pgtype to handle values of that type with the correct Codec.
The pgx.Conn LoadType method can return a *Type for array, composite, domain, and enum types by inspecting the database
metadata. This *Type can then be registered with Map.RegisterType.
For example, the following function could be called after a connection is established:
func RegisterDataTypes(ctx context.Context, conn *pgx.Conn) error {
dataTypeNames := []string{
"foo",
"_foo",
"bar",
"_bar",
}
for _, typeName := range dataTypeNames {
dataType, err := conn.LoadType(ctx, typeName)
if err != nil {
return err
}
conn.TypeMap().RegisterType(dataType)
}
return nil
}
A type cannot be registered unless all types it depends on are already registered. e.g. An array type cannot be
registered until its element type is registered.
ArrayCodec implements support for arrays. If pgtype supports type T then it can easily support []T by registering an
ArrayCodec for the appropriate PostgreSQL OID. In addition, Array[T] type can support multi-dimensional arrays.
CompositeCodec implements support for PostgreSQL composite types. Go structs can be scanned into if the public fields of
the struct are in the exact order and type of the PostgreSQL type or by implementing CompositeIndexScanner and
CompositeIndexGetter.
Domain types are treated as their underlying type if the underlying type and the domain type are registered.
PostgreSQL enums can usually be treated as text. However, EnumCodec implements support for interning strings which can
reduce memory usage.
While pgtype will often still work with unregistered types it is highly recommended that all types be registered due to
an improvement in performance and the elimination of certain edge cases.
If an entirely new PostgreSQL type (e.g. PostGIS types) is used then the application or a library can create a new
Codec. Then the OID / Codec mapping can be registered with Map.RegisterType. There is no difference between a Codec
defined and registered by the application and a Codec built in to pgtype. See any of the Codecs in pgtype for Codec
examples and for examples of type registration.
Encoding Unknown Types Encoding Unknown Types

View File

@ -1,14 +1,11 @@
package pgtype package pgtype
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"unicode"
"unicode/utf8"
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
@ -43,7 +40,7 @@ func (h *Hstore) Scan(src any) error {
switch src := src.(type) { switch src := src.(type) {
case string: case string:
return scanPlanTextAnyToHstoreScanner{}.Scan([]byte(src), h) return scanPlanTextAnyToHstoreScanner{}.scanString(src, h)
} }
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
@ -124,8 +121,15 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e
return nil, err return nil, err
} }
if hstore == nil { if len(hstore) == 0 {
return nil, nil // distinguish between empty and nil: Not strictly required by Postgres, since its protocol
// explicitly marks NULL column values separately. However, the Binary codec does this, and
// this means we can "round trip" Encode and Scan without data loss.
// nil: []byte(nil); empty: []byte{}
if hstore == nil {
return nil, nil
}
return []byte{}, nil
} }
firstPair := true firstPair := true
@ -134,16 +138,23 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e
if firstPair { if firstPair {
firstPair = false firstPair = false
} else { } else {
buf = append(buf, ',') buf = append(buf, ',', ' ')
} }
buf = append(buf, quoteHstoreElementIfNeeded(k)...) // unconditionally quote hstore keys/values like Postgres does
// this avoids a Mac OS X Postgres hstore parsing bug:
// https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com
buf = append(buf, '"')
buf = append(buf, quoteArrayReplacer.Replace(k)...)
buf = append(buf, '"')
buf = append(buf, "=>"...) buf = append(buf, "=>"...)
if v == nil { if v == nil {
buf = append(buf, "NULL"...) buf = append(buf, "NULL"...)
} else { } else {
buf = append(buf, quoteHstoreElementIfNeeded(*v)...) buf = append(buf, '"')
buf = append(buf, quoteArrayReplacer.Replace(*v)...)
buf = append(buf, '"')
} }
} }
@ -174,25 +185,28 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
scanner := (dst).(HstoreScanner) scanner := (dst).(HstoreScanner)
if src == nil { if src == nil {
return scanner.ScanHstore(Hstore{}) return scanner.ScanHstore(Hstore(nil))
} }
rp := 0 rp := 0
if len(src[rp:]) < 4 { const uint32Len = 4
if len(src[rp:]) < uint32Len {
return fmt.Errorf("hstore incomplete %v", src) return fmt.Errorf("hstore incomplete %v", src)
} }
pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) pairCount := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4 rp += uint32Len
hstore := make(Hstore, pairCount) hstore := make(Hstore, pairCount)
// one allocation for all *string, rather than one per string, just like text parsing
valueStrings := make([]string, pairCount)
for i := 0; i < pairCount; i++ { for i := 0; i < pairCount; i++ {
if len(src[rp:]) < 4 { if len(src[rp:]) < uint32Len {
return fmt.Errorf("hstore incomplete %v", src) return fmt.Errorf("hstore incomplete %v", src)
} }
keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) keyLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4 rp += uint32Len
if len(src[rp:]) < keyLen { if len(src[rp:]) < keyLen {
return fmt.Errorf("hstore incomplete %v", src) return fmt.Errorf("hstore incomplete %v", src)
@ -200,26 +214,17 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
key := string(src[rp : rp+keyLen]) key := string(src[rp : rp+keyLen])
rp += keyLen rp += keyLen
if len(src[rp:]) < 4 { if len(src[rp:]) < uint32Len {
return fmt.Errorf("hstore incomplete %v", src) return fmt.Errorf("hstore incomplete %v", src)
} }
valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4 rp += 4
var valueBuf []byte
if valueLen >= 0 { if valueLen >= 0 {
valueBuf = src[rp : rp+valueLen] valueStrings[i] = string(src[rp : rp+valueLen])
rp += valueLen rp += valueLen
}
var value Text hstore[key] = &valueStrings[i]
err := scanPlanTextAnyToTextScanner{}.Scan(valueBuf, &value)
if err != nil {
return err
}
if value.Valid {
hstore[key] = &value.String
} else { } else {
hstore[key] = nil hstore[key] = nil
} }
@ -230,28 +235,22 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error {
type scanPlanTextAnyToHstoreScanner struct{} type scanPlanTextAnyToHstoreScanner struct{}
func (scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error { func (s scanPlanTextAnyToHstoreScanner) Scan(src []byte, dst any) error {
scanner := (dst).(HstoreScanner) scanner := (dst).(HstoreScanner)
if src == nil { if src == nil {
return scanner.ScanHstore(Hstore{}) return scanner.ScanHstore(Hstore(nil))
} }
return s.scanString(string(src), scanner)
}
keys, values, err := parseHstore(string(src)) // scanString does not return nil hstore values because string cannot be nil.
func (scanPlanTextAnyToHstoreScanner) scanString(src string, scanner HstoreScanner) error {
hstore, err := parseHstore(src)
if err != nil { if err != nil {
return err return err
} }
return scanner.ScanHstore(hstore)
m := make(Hstore, len(keys))
for i := range keys {
if values[i].Valid {
m[keys[i]] = &values[i].String
} else {
m[keys[i]] = nil
}
}
return scanner.ScanHstore(m)
} }
func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { func (c HstoreCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
@ -271,191 +270,217 @@ func (c HstoreCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (
return hstore, nil return hstore, nil
} }
var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
func quoteHstoreElement(src string) string {
return `"` + quoteArrayReplacer.Replace(src) + `"`
}
func quoteHstoreElementIfNeeded(src string) string {
if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) {
return quoteArrayElement(src)
}
return src
}
const (
hsPre = iota
hsKey
hsSep
hsVal
hsNul
hsNext
)
type hstoreParser struct { type hstoreParser struct {
str string str string
pos int pos int
nextBackslash int
} }
func newHSP(in string) *hstoreParser { func newHSP(in string) *hstoreParser {
return &hstoreParser{ return &hstoreParser{
pos: 0, pos: 0,
str: in, str: in,
nextBackslash: strings.IndexByte(in, '\\'),
} }
} }
func (p *hstoreParser) Consume() (r rune, end bool) { func (p *hstoreParser) atEnd() bool {
return p.pos >= len(p.str)
}
// consume returns the next byte of the string, or end if the string is done.
func (p *hstoreParser) consume() (b byte, end bool) {
if p.pos >= len(p.str) { if p.pos >= len(p.str) {
end = true return 0, true
return
} }
r, w := utf8.DecodeRuneInString(p.str[p.pos:]) b = p.str[p.pos]
p.pos += w p.pos++
return return b, false
} }
func (p *hstoreParser) Peek() (r rune, end bool) { func unexpectedByteErr(actualB byte, expectedB byte) error {
if p.pos >= len(p.str) { return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB)
end = true
return
}
r, _ = utf8.DecodeRuneInString(p.str[p.pos:])
return
} }
// parseHstore parses the string representation of an hstore column (the same // consumeExpectedByte consumes expectedB from the string, or returns an error.
// you would get from an ordinary SELECT) into two slices of keys and values. it func (p *hstoreParser) consumeExpectedByte(expectedB byte) error {
// is used internally in the default parsing of hstores. nextB, end := p.consume()
func parseHstore(s string) (k []string, v []Text, err error) { if end {
if s == "" { return fmt.Errorf("expected '%c' ('%#v'); found end", expectedB, expectedB)
return }
if nextB != expectedB {
return unexpectedByteErr(nextB, expectedB)
}
return nil
}
// consumeExpected2 consumes two expected bytes or returns an error.
// This was a bit faster than using a string argument (better inlining? Not sure).
func (p *hstoreParser) consumeExpected2(one byte, two byte) error {
if p.pos+2 > len(p.str) {
return errors.New("unexpected end of string")
}
if p.str[p.pos] != one {
return unexpectedByteErr(p.str[p.pos], one)
}
if p.str[p.pos+1] != two {
return unexpectedByteErr(p.str[p.pos+1], two)
}
p.pos += 2
return nil
}
var errEOSInQuoted = errors.New(`found end before closing double-quote ('"')`)
// consumeDoubleQuoted consumes a double-quoted string from p. The double quote must have been
// parsed already. This copies the string from the backing string so it can be garbage collected.
func (p *hstoreParser) consumeDoubleQuoted() (string, error) {
// fast path: assume most keys/values do not contain escapes
nextDoubleQuote := strings.IndexByte(p.str[p.pos:], '"')
if nextDoubleQuote == -1 {
return "", errEOSInQuoted
}
nextDoubleQuote += p.pos
if p.nextBackslash == -1 || p.nextBackslash > nextDoubleQuote {
// clone the string from the source string to ensure it can be garbage collected separately
// TODO: use strings.Clone on Go 1.20; this could get optimized away
s := strings.Clone(p.str[p.pos:nextDoubleQuote])
p.pos = nextDoubleQuote + 1
return s, nil
} }
buf := bytes.Buffer{} // slow path: string contains escapes
keys := []string{} s, err := p.consumeDoubleQuotedWithEscapes(p.nextBackslash)
values := []Text{} p.nextBackslash = strings.IndexByte(p.str[p.pos:], '\\')
if p.nextBackslash != -1 {
p.nextBackslash += p.pos
}
return s, err
}
// consumeDoubleQuotedWithEscapes consumes a double-quoted string containing escapes, starting
// at p.pos, and with the first backslash at firstBackslash. This copies the string so it can be
// garbage collected separately.
func (p *hstoreParser) consumeDoubleQuotedWithEscapes(firstBackslash int) (string, error) {
// copy the prefix that does not contain backslashes
var builder strings.Builder
builder.WriteString(p.str[p.pos:firstBackslash])
// skip to the backslash
p.pos = firstBackslash
// copy bytes until the end, unescaping backslashes
for {
nextB, end := p.consume()
if end {
return "", errEOSInQuoted
} else if nextB == '"' {
break
} else if nextB == '\\' {
// escape: skip the backslash and copy the char
nextB, end = p.consume()
if end {
return "", errEOSInQuoted
}
if !(nextB == '\\' || nextB == '"') {
return "", fmt.Errorf("unexpected escape in quoted string: found '%#v'", nextB)
}
builder.WriteByte(nextB)
} else {
// normal byte: copy it
builder.WriteByte(nextB)
}
}
return builder.String(), nil
}
// consumePairSeparator consumes the Hstore pair separator ", " or returns an error.
func (p *hstoreParser) consumePairSeparator() error {
return p.consumeExpected2(',', ' ')
}
// consumeKVSeparator consumes the Hstore key/value separator "=>" or returns an error.
func (p *hstoreParser) consumeKVSeparator() error {
return p.consumeExpected2('=', '>')
}
// consumeDoubleQuotedOrNull consumes the Hstore key/value separator "=>" or returns an error.
func (p *hstoreParser) consumeDoubleQuotedOrNull() (Text, error) {
// peek at the next byte
if p.atEnd() {
return Text{}, errors.New("found end instead of value")
}
next := p.str[p.pos]
if next == 'N' {
// must be the exact string NULL: use consumeExpected2 twice
err := p.consumeExpected2('N', 'U')
if err != nil {
return Text{}, err
}
err = p.consumeExpected2('L', 'L')
if err != nil {
return Text{}, err
}
return Text{String: "", Valid: false}, nil
} else if next != '"' {
return Text{}, unexpectedByteErr(next, '"')
}
// skip the double quote
p.pos += 1
s, err := p.consumeDoubleQuoted()
if err != nil {
return Text{}, err
}
return Text{String: s, Valid: true}, nil
}
func parseHstore(s string) (Hstore, error) {
p := newHSP(s) p := newHSP(s)
r, end := p.Consume() // This is an over-estimate of the number of key/value pairs. Use '>' because I am guessing it
state := hsPre // is less likely to occur in keys/values than '=' or ','.
numPairsEstimate := strings.Count(s, ">")
for !end { // makes one allocation of strings for the entire Hstore, rather than one allocation per value.
switch state { valueStrings := make([]string, 0, numPairsEstimate)
case hsPre: result := make(Hstore, numPairsEstimate)
if r == '"' { first := true
state = hsKey for !p.atEnd() {
} else { if !first {
err = errors.New("String does not begin with \"") err := p.consumePairSeparator()
} if err != nil {
case hsKey: return nil, err
switch r {
case '"': //End of the key
keys = append(keys, buf.String())
buf = bytes.Buffer{}
state = hsSep
case '\\': //Potential escaped character
n, end := p.Consume()
switch {
case end:
err = errors.New("Found EOS in key, expecting character or \"")
case n == '"', n == '\\':
buf.WriteRune(n)
default:
buf.WriteRune(r)
buf.WriteRune(n)
}
default: //Any other character
buf.WriteRune(r)
}
case hsSep:
if r == '=' {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after '=', expecting '>'")
case r == '>':
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'")
case r == '"':
state = hsVal
case r == 'N':
state = hsNul
default:
err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r)
}
default:
err = fmt.Errorf("Invalid character after '=', expecting '>'")
}
} else {
err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r)
}
case hsVal:
switch r {
case '"': //End of the value
values = append(values, Text{String: buf.String(), Valid: true})
buf = bytes.Buffer{}
state = hsNext
case '\\': //Potential escaped character
n, end := p.Consume()
switch {
case end:
err = errors.New("Found EOS in key, expecting character or \"")
case n == '"', n == '\\':
buf.WriteRune(n)
default:
buf.WriteRune(r)
buf.WriteRune(n)
}
default: //Any other character
buf.WriteRune(r)
}
case hsNul:
nulBuf := make([]rune, 3)
nulBuf[0] = r
for i := 1; i < 3; i++ {
r, end = p.Consume()
if end {
err = errors.New("Found EOS in NULL value")
return
}
nulBuf[i] = r
}
if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' {
values = append(values, Text{})
state = hsNext
} else {
err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf))
}
case hsNext:
if r == ',' {
r, end = p.Consume()
switch {
case end:
err = errors.New("Found EOS after ',', expcting space")
case (unicode.IsSpace(r)):
r, end = p.Consume()
state = hsKey
default:
err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r)
}
} else {
err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r)
} }
} else {
first = false
} }
err := p.consumeExpectedByte('"')
if err != nil { if err != nil {
return return nil, err
}
key, err := p.consumeDoubleQuoted()
if err != nil {
return nil, err
}
err = p.consumeKVSeparator()
if err != nil {
return nil, err
}
value, err := p.consumeDoubleQuotedOrNull()
if err != nil {
return nil, err
}
if value.Valid {
valueStrings = append(valueStrings, value.String)
result[key] = &valueStrings[len(valueStrings)-1]
} else {
result[key] = nil
} }
r, end = p.Consume()
} }
if state != hsNext {
err = errors.New("Improperly formatted hstore") return result, nil
return
}
k = keys
v = values
return
} }

View File

@ -33,7 +33,7 @@ func (dst *Int2) ScanInt64(n Int8) error {
} }
if n.Int64 < math.MinInt16 { if n.Int64 < math.MinInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) return fmt.Errorf("%d is less than minimum value for Int2", n.Int64)
} }
if n.Int64 > math.MaxInt16 { if n.Int64 > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64)
@ -593,7 +593,7 @@ func (dst *Int4) ScanInt64(n Int8) error {
} }
if n.Int64 < math.MinInt32 { if n.Int64 < math.MinInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) return fmt.Errorf("%d is less than minimum value for Int4", n.Int64)
} }
if n.Int64 > math.MaxInt32 { if n.Int64 > math.MaxInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64)
@ -1164,7 +1164,7 @@ func (dst *Int8) ScanInt64(n Int8) error {
} }
if n.Int64 < math.MinInt64 { if n.Int64 < math.MinInt64 {
return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) return fmt.Errorf("%d is less than minimum value for Int8", n.Int64)
} }
if n.Int64 > math.MaxInt64 { if n.Int64 > math.MaxInt64 {
return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64)

View File

@ -3,6 +3,7 @@ package pgtype
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"encoding/json"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
@ -34,7 +35,7 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error {
} }
if n.Int64 < math.MinInt<%= pg_bit_size %> { if n.Int64 < math.MinInt<%= pg_bit_size %> {
return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64)
} }
if n.Int64 > math.MaxInt<%= pg_bit_size %> { if n.Int64 > math.MaxInt<%= pg_bit_size %> {
return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64)

View File

@ -1,6 +1,7 @@
package pgtype package pgtype
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -23,6 +24,19 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
return encodePlanJSONCodecEitherFormatString{} return encodePlanJSONCodecEitherFormatString{}
case []byte: case []byte:
return encodePlanJSONCodecEitherFormatByteSlice{} return encodePlanJSONCodecEitherFormatByteSlice{}
// Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be
// marshalled.
//
// https://github.com/jackc/pgx/issues/1681
case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{}
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
//
// https://github.com/jackc/pgx/issues/1430
case driver.Valuer:
return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format}
} }
// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
@ -78,14 +92,36 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
switch target.(type) { switch target.(type) {
case *string: case *string:
return scanPlanAnyToString{} return scanPlanAnyToString{}
case **string:
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
// solution would be.
//
// https://github.com/jackc/pgx/issues/1470 -- **string
// https://github.com/jackc/pgx/issues/1691 -- ** anything else
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil {
if _, failed := nextPlan.(*scanPlanFail); !failed {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
}
}
case *[]byte: case *[]byte:
return scanPlanJSONToByteSlice{} return scanPlanJSONToByteSlice{}
case BytesScanner: case BytesScanner:
return scanPlanBinaryBytesToBytesScanner{} return scanPlanBinaryBytesToBytesScanner{}
default:
return scanPlanJSONToJSONUnmarshal{} // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
//
// https://github.com/jackc/pgx/issues/1418
case sql.Scanner:
return &scanPlanSQLScanner{formatCode: format}
} }
return scanPlanJSONToJSONUnmarshal{}
} }
type scanPlanAnyToString struct{} type scanPlanAnyToString struct{}
@ -125,7 +161,7 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if dstValue.Kind() == reflect.Ptr { if dstValue.Kind() == reflect.Ptr {
el := dstValue.Elem() el := dstValue.Elem()
switch el.Kind() { switch el.Kind() {
case reflect.Ptr, reflect.Slice, reflect.Map: case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface:
el.Set(reflect.Zero(el.Type())) el.Set(reflect.Zero(el.Type()))
return nil return nil
} }

View File

@ -33,23 +33,6 @@ var big10 *big.Int = big.NewInt(10)
var big100 *big.Int = big.NewInt(100) var big100 *big.Int = big.NewInt(100)
var big1000 *big.Int = big.NewInt(1000) var big1000 *big.Int = big.NewInt(1000)
var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8)
var bigMinInt8 *big.Int = big.NewInt(math.MinInt8)
var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16)
var bigMinInt16 *big.Int = big.NewInt(math.MinInt16)
var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32)
var bigMinInt32 *big.Int = big.NewInt(math.MinInt32)
var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64)
var bigMinInt64 *big.Int = big.NewInt(math.MinInt64)
var bigMaxInt *big.Int = big.NewInt(int64(maxInt))
var bigMinInt *big.Int = big.NewInt(int64(minInt))
var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8)
var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16)
var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32)
var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64))
var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint))
var bigNBase *big.Int = big.NewInt(nbase) var bigNBase *big.Int = big.NewInt(nbase)
var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
@ -161,20 +144,20 @@ func (n *Numeric) toBigInt() (*big.Int, error) {
} }
func parseNumericString(str string) (n *big.Int, exp int32, err error) { func parseNumericString(str string) (n *big.Int, exp int32, err error) {
parts := strings.SplitN(str, ".", 2) idx := strings.IndexByte(str, '.')
digits := strings.Join(parts, "")
if len(parts) > 1 { if idx == -1 {
exp = int32(-len(parts[1])) for len(str) > 1 && str[len(str)-1] == '0' && str[len(str)-2] != '-' {
} else { str = str[:len(str)-1]
for len(digits) > 1 && digits[len(digits)-1] == '0' && digits[len(digits)-2] != '-' {
digits = digits[:len(digits)-1]
exp++ exp++
} }
} else {
exp = int32(-(len(str) - idx - 1))
str = str[:idx] + str[idx+1:]
} }
accum := &big.Int{} accum := &big.Int{}
if _, ok := accum.SetString(digits, 10); !ok { if _, ok := accum.SetString(str, 10); !ok {
return nil, 0, fmt.Errorf("%s is not a number", str) return nil, 0, fmt.Errorf("%s is not a number", str)
} }
@ -240,10 +223,29 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
return n.numberTextBytes(), nil return n.numberTextBytes(), nil
} }
func (n *Numeric) UnmarshalJSON(src []byte) error {
if bytes.Equal(src, []byte(`null`)) {
*n = Numeric{}
return nil
}
if bytes.Equal(src, []byte(`"NaN"`)) {
*n = Numeric{NaN: true, Valid: true}
return nil
}
return scanPlanTextAnyToNumericScanner{}.Scan(src, n)
}
// numberString returns a string of the number. undefined if NaN, infinite, or NULL // numberString returns a string of the number. undefined if NaN, infinite, or NULL
func (n Numeric) numberTextBytes() []byte { func (n Numeric) numberTextBytes() []byte {
intStr := n.Int.String() intStr := n.Int.String()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if len(intStr) > 0 && intStr[:1] == "-" {
intStr = intStr[1:]
buf.WriteByte('-')
}
exp := int(n.Exp) exp := int(n.Exp)
if exp > 0 { if exp > 0 {
buf.WriteString(intStr) buf.WriteString(intStr)

View File

@ -44,7 +44,7 @@ const (
MacaddrOID = 829 MacaddrOID = 829
InetOID = 869 InetOID = 869
BoolArrayOID = 1000 BoolArrayOID = 1000
QCharArrayOID = 1003 QCharArrayOID = 1002
NameArrayOID = 1003 NameArrayOID = 1003
Int2ArrayOID = 1005 Int2ArrayOID = 1005
Int4ArrayOID = 1007 Int4ArrayOID = 1007
@ -104,6 +104,8 @@ const (
TstzrangeArrayOID = 3911 TstzrangeArrayOID = 3911
Int8rangeOID = 3926 Int8rangeOID = 3926
Int8rangeArrayOID = 3927 Int8rangeArrayOID = 3927
JSONPathOID = 4072
JSONPathArrayOID = 4073
Int4multirangeOID = 4451 Int4multirangeOID = 4451
NummultirangeOID = 4532 NummultirangeOID = 4532
TsmultirangeOID = 4533 TsmultirangeOID = 4533
@ -145,7 +147,7 @@ const (
BinaryFormatCode = 1 BinaryFormatCode = 1
) )
// A Codec converts between Go and PostgreSQL values. // A Codec converts between Go and PostgreSQL values. A Codec must not be mutated after it is registered with a Map.
type Codec interface { type Codec interface {
// FormatSupported returns true if the format is supported. // FormatSupported returns true if the format is supported.
FormatSupported(int16) bool FormatSupported(int16) bool
@ -176,6 +178,7 @@ func (e *nullAssignmentError) Error() string {
return fmt.Sprintf("cannot assign NULL to %T", e.dst) return fmt.Sprintf("cannot assign NULL to %T", e.dst)
} }
// Type represents a PostgreSQL data type. It must not be mutated after it is registered with a Map.
type Type struct { type Type struct {
Codec Codec Codec Codec
Name string Name string
@ -192,7 +195,8 @@ type Map struct {
reflectTypeToType map[reflect.Type]*Type reflectTypeToType map[reflect.Type]*Type
memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan
memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan
// TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every
// time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers // time a wrapper is found the PlanEncode method will be recursively called with the new value. This allows several layers of wrappers
@ -208,13 +212,16 @@ type Map struct {
} }
func NewMap() *Map { func NewMap() *Map {
m := &Map{ defaultMapInitOnce.Do(initDefaultMap)
return &Map{
oidToType: make(map[uint32]*Type), oidToType: make(map[uint32]*Type),
nameToType: make(map[string]*Type), nameToType: make(map[string]*Type),
reflectTypeToName: make(map[reflect.Type]string), reflectTypeToName: make(map[reflect.Type]string),
oidToFormatCode: make(map[uint32]int16), oidToFormatCode: make(map[uint32]int16),
memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan),
memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan),
TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{
TryWrapDerefPointerEncodePlan, TryWrapDerefPointerEncodePlan,
@ -223,6 +230,7 @@ func NewMap() *Map {
TryWrapStructEncodePlan, TryWrapStructEncodePlan,
TryWrapSliceEncodePlan, TryWrapSliceEncodePlan,
TryWrapMultiDimSliceEncodePlan, TryWrapMultiDimSliceEncodePlan,
TryWrapArrayEncodePlan,
}, },
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
@ -232,184 +240,12 @@ func NewMap() *Map {
TryWrapStructScanPlan, TryWrapStructScanPlan,
TryWrapPtrSliceScanPlan, TryWrapPtrSliceScanPlan,
TryWrapPtrMultiDimSliceScanPlan, TryWrapPtrMultiDimSliceScanPlan,
TryWrapPtrArrayScanPlan,
}, },
} }
// Base types
m.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
m.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}})
m.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}})
m.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}})
m.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}})
m.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}})
m.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}})
m.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}})
m.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}})
m.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}})
m.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}})
m.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}})
m.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}})
m.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}})
m.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
m.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
m.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
m.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
m.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
m.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
m.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
m.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
m.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
m.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}})
m.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}})
m.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}})
m.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}})
m.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}})
m.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
m.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
m.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}})
m.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}})
m.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}})
m.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})
m.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
m.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
// Range types
m.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: m.oidToType[DateOID]}})
m.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int4OID]}})
m.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: m.oidToType[Int8OID]}})
m.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[NumericOID]}})
m.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestampOID]}})
m.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: m.oidToType[TimestamptzOID]}})
// Multirange types
m.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[DaterangeOID]}})
m.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int4rangeOID]}})
m.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[Int8rangeOID]}})
m.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[NumrangeOID]}})
m.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TsrangeOID]}})
m.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: m.oidToType[TstzrangeOID]}})
// Array types
m.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ACLItemOID]}})
m.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BitOID]}})
m.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoolOID]}})
m.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BoxOID]}})
m.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[BPCharOID]}})
m.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[ByteaOID]}})
m.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[QCharOID]}})
m.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDOID]}})
m.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CIDROID]}})
m.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[CircleOID]}})
m.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DateOID]}})
m.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[DaterangeOID]}})
m.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float4OID]}})
m.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Float8OID]}})
m.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[InetOID]}})
m.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int2OID]}})
m.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4OID]}})
m.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int4rangeOID]}})
m.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8OID]}})
m.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[Int8rangeOID]}})
m.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[IntervalOID]}})
m.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONOID]}})
m.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[JSONBOID]}})
m.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LineOID]}})
m.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[LsegOID]}})
m.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[MacaddrOID]}})
m.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NameOID]}})
m.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumericOID]}})
m.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[NumrangeOID]}})
m.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[OIDOID]}})
m.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PathOID]}})
m.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PointOID]}})
m.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[PolygonOID]}})
m.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[RecordOID]}})
m.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TextOID]}})
m.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TIDOID]}})
m.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimeOID]}})
m.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestampOID]}})
m.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TimestamptzOID]}})
m.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TsrangeOID]}})
m.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[TstzrangeOID]}})
m.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[UUIDOID]}})
m.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarbitOID]}})
m.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[VarcharOID]}})
m.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: m.oidToType[XIDOID]}})
// Integer types that directly map to a PostgreSQL type
registerDefaultPgTypeVariants[int16](m, "int2")
registerDefaultPgTypeVariants[int32](m, "int4")
registerDefaultPgTypeVariants[int64](m, "int8")
// Integer types that do not have a direct match to a PostgreSQL type
registerDefaultPgTypeVariants[int8](m, "int8")
registerDefaultPgTypeVariants[int](m, "int8")
registerDefaultPgTypeVariants[uint8](m, "int8")
registerDefaultPgTypeVariants[uint16](m, "int8")
registerDefaultPgTypeVariants[uint32](m, "int8")
registerDefaultPgTypeVariants[uint64](m, "numeric")
registerDefaultPgTypeVariants[uint](m, "numeric")
registerDefaultPgTypeVariants[float32](m, "float4")
registerDefaultPgTypeVariants[float64](m, "float8")
registerDefaultPgTypeVariants[bool](m, "bool")
registerDefaultPgTypeVariants[time.Time](m, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](m, "interval")
registerDefaultPgTypeVariants[string](m, "text")
registerDefaultPgTypeVariants[[]byte](m, "bytea")
registerDefaultPgTypeVariants[net.IP](m, "inet")
registerDefaultPgTypeVariants[net.IPNet](m, "cidr")
registerDefaultPgTypeVariants[netip.Addr](m, "inet")
registerDefaultPgTypeVariants[netip.Prefix](m, "cidr")
// pgtype provided structs
registerDefaultPgTypeVariants[Bits](m, "varbit")
registerDefaultPgTypeVariants[Bool](m, "bool")
registerDefaultPgTypeVariants[Box](m, "box")
registerDefaultPgTypeVariants[Circle](m, "circle")
registerDefaultPgTypeVariants[Date](m, "date")
registerDefaultPgTypeVariants[Range[Date]](m, "daterange")
registerDefaultPgTypeVariants[Multirange[Range[Date]]](m, "datemultirange")
registerDefaultPgTypeVariants[Float4](m, "float4")
registerDefaultPgTypeVariants[Float8](m, "float8")
registerDefaultPgTypeVariants[Range[Float8]](m, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange.
registerDefaultPgTypeVariants[Multirange[Range[Float8]]](m, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange.
registerDefaultPgTypeVariants[Int2](m, "int2")
registerDefaultPgTypeVariants[Int4](m, "int4")
registerDefaultPgTypeVariants[Range[Int4]](m, "int4range")
registerDefaultPgTypeVariants[Multirange[Range[Int4]]](m, "int4multirange")
registerDefaultPgTypeVariants[Int8](m, "int8")
registerDefaultPgTypeVariants[Range[Int8]](m, "int8range")
registerDefaultPgTypeVariants[Multirange[Range[Int8]]](m, "int8multirange")
registerDefaultPgTypeVariants[Interval](m, "interval")
registerDefaultPgTypeVariants[Line](m, "line")
registerDefaultPgTypeVariants[Lseg](m, "lseg")
registerDefaultPgTypeVariants[Numeric](m, "numeric")
registerDefaultPgTypeVariants[Range[Numeric]](m, "numrange")
registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](m, "nummultirange")
registerDefaultPgTypeVariants[Path](m, "path")
registerDefaultPgTypeVariants[Point](m, "point")
registerDefaultPgTypeVariants[Polygon](m, "polygon")
registerDefaultPgTypeVariants[TID](m, "tid")
registerDefaultPgTypeVariants[Text](m, "text")
registerDefaultPgTypeVariants[Time](m, "time")
registerDefaultPgTypeVariants[Timestamp](m, "timestamp")
registerDefaultPgTypeVariants[Timestamptz](m, "timestamptz")
registerDefaultPgTypeVariants[Range[Timestamp]](m, "tsrange")
registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](m, "tsmultirange")
registerDefaultPgTypeVariants[Range[Timestamptz]](m, "tstzrange")
registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](m, "tstzmultirange")
registerDefaultPgTypeVariants[UUID](m, "uuid")
return m
} }
// RegisterType registers a data type with the Map. t must not be mutated after it is registered.
func (m *Map) RegisterType(t *Type) { func (m *Map) RegisterType(t *Type) {
m.oidToType[t.OID] = t m.oidToType[t.OID] = t
m.nameToType[t.Name] = t m.nameToType[t.Name] = t
@ -420,6 +256,9 @@ func (m *Map) RegisterType(t *Type) {
for k := range m.memoizedScanPlans { for k := range m.memoizedScanPlans {
delete(m.memoizedScanPlans, k) delete(m.memoizedScanPlans, k)
} }
for k := range m.memoizedEncodePlans {
delete(m.memoizedEncodePlans, k)
}
} }
// RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be // RegisterDefaultPgType registers a mapping of a Go type to a PostgreSQL type name. Typically the data type to be
@ -433,15 +272,27 @@ func (m *Map) RegisterDefaultPgType(value any, name string) {
for k := range m.memoizedScanPlans { for k := range m.memoizedScanPlans {
delete(m.memoizedScanPlans, k) delete(m.memoizedScanPlans, k)
} }
for k := range m.memoizedEncodePlans {
delete(m.memoizedEncodePlans, k)
}
} }
// TypeForOID returns the Type registered for the given OID. The returned Type must not be mutated.
func (m *Map) TypeForOID(oid uint32) (*Type, bool) { func (m *Map) TypeForOID(oid uint32) (*Type, bool) {
dt, ok := m.oidToType[oid] if dt, ok := m.oidToType[oid]; ok {
return dt, true
}
dt, ok := defaultMap.oidToType[oid]
return dt, ok return dt, ok
} }
// TypeForName returns the Type registered for the given name. The returned Type must not be mutated.
func (m *Map) TypeForName(name string) (*Type, bool) { func (m *Map) TypeForName(name string) (*Type, bool) {
dt, ok := m.nameToType[name] if dt, ok := m.nameToType[name]; ok {
return dt, true
}
dt, ok := defaultMap.nameToType[name]
return dt, ok return dt, ok
} }
@ -449,30 +300,39 @@ func (m *Map) buildReflectTypeToType() {
m.reflectTypeToType = make(map[reflect.Type]*Type) m.reflectTypeToType = make(map[reflect.Type]*Type)
for reflectType, name := range m.reflectTypeToName { for reflectType, name := range m.reflectTypeToName {
if dt, ok := m.nameToType[name]; ok { if dt, ok := m.TypeForName(name); ok {
m.reflectTypeToType[reflectType] = dt m.reflectTypeToType[reflectType] = dt
} }
} }
} }
// TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode // TypeForValue finds a data type suitable for v. Use RegisterType to register types that can encode and decode
// themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. // themselves. Use RegisterDefaultPgType to register that can be handled by a registered data type. The returned Type
// must not be mutated.
func (m *Map) TypeForValue(v any) (*Type, bool) { func (m *Map) TypeForValue(v any) (*Type, bool) {
if m.reflectTypeToType == nil { if m.reflectTypeToType == nil {
m.buildReflectTypeToType() m.buildReflectTypeToType()
} }
dt, ok := m.reflectTypeToType[reflect.TypeOf(v)] if dt, ok := m.reflectTypeToType[reflect.TypeOf(v)]; ok {
return dt, true
}
dt, ok := defaultMap.reflectTypeToType[reflect.TypeOf(v)]
return dt, ok return dt, ok
} }
// FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text // FormatCodeForOID returns the preferred format code for type oid. If the type is not registered it returns the text
// format code. // format code.
func (m *Map) FormatCodeForOID(oid uint32) int16 { func (m *Map) FormatCodeForOID(oid uint32) int16 {
fc, ok := m.oidToFormatCode[oid] if fc, ok := m.oidToFormatCode[oid]; ok {
if ok {
return fc return fc
} }
if fc, ok := defaultMap.oidToFormatCode[oid]; ok {
return fc
}
return TextFormatCode return TextFormatCode
} }
@ -573,6 +433,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error {
return plan.Scan(src, dst) return plan.Scan(src, dst)
} }
} }
for oid := range defaultMap.oidToType {
if _, ok := plan.m.oidToType[oid]; !ok {
plan := plan.m.planScan(oid, plan.formatCode, dst)
if _, ok := plan.(*scanPlanFail); !ok {
return plan.Scan(src, dst)
}
}
}
} }
var format string var format string
@ -586,7 +454,7 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error {
} }
var dataTypeName string var dataTypeName string
if t, ok := plan.m.oidToType[plan.oid]; ok { if t, ok := plan.m.TypeForOID(plan.oid); ok {
dataTypeName = t.Name dataTypeName = t.Name
} else { } else {
dataTypeName = "unknown type" dataTypeName = "unknown type"
@ -652,6 +520,7 @@ var elemKindToPointerTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]refl
reflect.Float32: reflect.TypeOf(new(float32)), reflect.Float32: reflect.TypeOf(new(float32)),
reflect.Float64: reflect.TypeOf(new(float64)), reflect.Float64: reflect.TypeOf(new(float64)),
reflect.String: reflect.TypeOf(new(string)), reflect.String: reflect.TypeOf(new(string)),
reflect.Bool: reflect.TypeOf(new(bool)),
} }
type underlyingTypeScanPlan struct { type underlyingTypeScanPlan struct {
@ -1018,7 +887,7 @@ func TryWrapStructScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValu
var targetElemValue reflect.Value var targetElemValue reflect.Value
if targetValue.IsNil() { if targetValue.IsNil() {
targetElemValue = reflect.New(targetValue.Type().Elem()) targetElemValue = reflect.Zero(targetValue.Type().Elem())
} else { } else {
targetElemValue = targetValue.Elem() targetElemValue = targetValue.Elem()
} }
@ -1075,15 +944,16 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa
return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true
} }
targetValue := reflect.ValueOf(target) targetType := reflect.TypeOf(target)
if targetValue.Kind() != reflect.Ptr { if targetType.Kind() != reflect.Ptr {
return nil, nil, false return nil, nil, false
} }
targetElemValue := targetValue.Elem() targetElemType := targetType.Elem()
if targetElemValue.Kind() == reflect.Slice { if targetElemType.Kind() == reflect.Slice {
return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: targetElemValue}, true slice := reflect.New(targetElemType).Elem()
return &wrapPtrSliceReflectScanPlan{}, &anySliceArrayReflect{slice: slice}, true
} }
return nil, nil, false return nil, nil, false
} }
@ -1139,6 +1009,31 @@ func (plan *wrapPtrMultiDimSliceScanPlan) Scan(src []byte, target any) error {
return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()}) return plan.next.Scan(src, &anyMultiDimSliceArray{slice: reflect.ValueOf(target).Elem()})
} }
// TryWrapPtrArrayScanPlan tries to wrap a pointer to a single dimension array.
func TryWrapPtrArrayScanPlan(target any) (plan WrappedScanPlanNextSetter, nextValue any, ok bool) {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return nil, nil, false
}
targetElemValue := targetValue.Elem()
if targetElemValue.Kind() == reflect.Array {
return &wrapPtrArrayReflectScanPlan{}, &anyArrayArrayReflect{array: targetElemValue}, true
}
return nil, nil, false
}
type wrapPtrArrayReflectScanPlan struct {
next ScanPlan
}
func (plan *wrapPtrArrayReflectScanPlan) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error {
return plan.next.Scan(src, &anyArrayArrayReflect{array: reflect.ValueOf(target).Elem()})
}
// PlanScan prepares a plan to scan a value into target. // PlanScan prepares a plan to scan a value into target.
func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {
oidMemo := m.memoizedScanPlans[oid] oidMemo := m.memoizedScanPlans[oid]
@ -1159,6 +1054,10 @@ func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan {
} }
func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan {
if target == nil {
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
}
if _, ok := target.(*UndecodedBytes); ok { if _, ok := target.(*UndecodedBytes); ok {
return scanPlanAnyToUndecodedBytes{} return scanPlanAnyToUndecodedBytes{}
} }
@ -1200,6 +1099,18 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan {
} }
} }
// This needs to happen before trying m.TryWrapScanPlanFuncs. Otherwise, a sql.Scanner would not get called if it was
// defined on a type that could be unwrapped such as `type myString string`.
//
// https://github.com/jackc/pgtype/issues/197
if _, ok := target.(sql.Scanner); ok {
if dt == nil {
return &scanPlanSQLScanner{formatCode: formatCode}
} else {
return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode}
}
}
for _, f := range m.TryWrapScanPlanFuncs { for _, f := range m.TryWrapScanPlanFuncs {
if wrapperPlan, nextDst, ok := f(target); ok { if wrapperPlan, nextDst, ok := f(target); ok {
if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil {
@ -1215,14 +1126,6 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan {
if _, ok := target.(*any); ok { if _, ok := target.(*any); ok {
return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode}
} }
if _, ok := target.(sql.Scanner); ok {
return &scanPlanCodecSQLScanner{c: dt.Codec, m: m, oid: oid, formatCode: formatCode}
}
}
if _, ok := target.(sql.Scanner); ok {
return &scanPlanSQLScanner{formatCode: formatCode}
} }
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
@ -1237,25 +1140,6 @@ func (m *Map) Scan(oid uint32, formatCode int16, src []byte, dst any) error {
return plan.Scan(src, dst) return plan.Scan(src, dst)
} }
func scanUnknownType(oid uint32, formatCode int16, buf []byte, dest any) error {
switch dest := dest.(type) {
case *string:
if formatCode == BinaryFormatCode {
return fmt.Errorf("unknown oid %d in binary format cannot be scanned into %T", oid, dest)
}
*dest = string(buf)
return nil
case *[]byte:
*dest = buf
return nil
default:
if nextDst, retry := GetAssignToDstType(dest); retry {
return scanUnknownType(oid, formatCode, buf, nextDst)
}
return fmt.Errorf("unknown oid %d cannot be scanned into %T", oid, dest)
}
}
var ErrScanTargetTypeChanged = errors.New("scan target type changed") var ErrScanTargetTypeChanged = errors.New("scan target type changed")
func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error { func codecScan(codec Codec, m *Map, oid uint32, format int16, src []byte, dst any) error {
@ -1289,6 +1173,24 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src
// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be // PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be
// found then nil is returned. // found then nil is returned.
func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan {
oidMemo := m.memoizedEncodePlans[oid]
if oidMemo == nil {
oidMemo = make(map[reflect.Type][2]EncodePlan)
m.memoizedEncodePlans[oid] = oidMemo
}
targetReflectType := reflect.TypeOf(value)
typeMemo := oidMemo[targetReflectType]
plan := typeMemo[format]
if plan == nil {
plan = m.planEncode(oid, format, value)
typeMemo[format] = plan
oidMemo[targetReflectType] = typeMemo
}
return plan
}
func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan {
if format == TextFormatCode { if format == TextFormatCode {
switch value.(type) { switch value.(type) {
case string: case string:
@ -1299,16 +1201,16 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan {
} }
var dt *Type var dt *Type
if dataType, ok := m.TypeForOID(oid); ok {
if oid == 0 { dt = dataType
} else {
// If no type for the OID was found, then either it is unknowable (e.g. the simple protocol) or it is an
// unregistered type. In either case try to find the type and OID that matches the value (e.g. a []byte would be
// registered to PostgreSQL bytea).
if dataType, ok := m.TypeForValue(value); ok { if dataType, ok := m.TypeForValue(value); ok {
dt = dataType dt = dataType
oid = dt.OID // Preserve assumed OID in case we are recursively called below. oid = dt.OID // Preserve assumed OID in case we are recursively called below.
} }
} else {
if dataType, ok := m.TypeForOID(oid); ok {
dt = dataType
}
} }
if dt != nil { if dt != nil {
@ -1453,6 +1355,7 @@ var kindToTypes map[reflect.Kind]reflect.Type = map[reflect.Kind]reflect.Type{
reflect.Float32: reflect.TypeOf(float32(0)), reflect.Float32: reflect.TypeOf(float32(0)),
reflect.Float64: reflect.TypeOf(float64(0)), reflect.Float64: reflect.TypeOf(float64(0)),
reflect.String: reflect.TypeOf(""), reflect.String: reflect.TypeOf(""),
reflect.Bool: reflect.TypeOf(false),
} }
type underlyingTypeEncodePlan struct { type underlyingTypeEncodePlan struct {
@ -1884,11 +1787,7 @@ type wrapSliceEncodePlan[T any] struct {
func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next } func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) { func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
w := anySliceArrayReflect{ return plan.next.Encode((FlatArray[T])(value.([]T)), buf)
slice: reflect.ValueOf(value),
}
return plan.next.Encode(w, buf)
} }
type wrapSliceEncodeReflectPlan struct { type wrapSliceEncodeReflectPlan struct {
@ -1941,6 +1840,35 @@ func (plan *wrapMultiDimSliceEncodePlan) Encode(value any, buf []byte) (newBuf [
return plan.next.Encode(&w, buf) return plan.next.Encode(&w, buf)
} }
func TryWrapArrayEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) {
if _, ok := value.(driver.Valuer); ok {
return nil, nil, false
}
if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Array {
w := anyArrayArrayReflect{
array: reflect.ValueOf(value),
}
return &wrapArrayEncodeReflectPlan{}, w, true
}
return nil, nil, false
}
type wrapArrayEncodeReflectPlan struct {
next EncodePlan
}
func (plan *wrapArrayEncodeReflectPlan) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapArrayEncodeReflectPlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
w := anyArrayArrayReflect{
array: reflect.ValueOf(value),
}
return plan.next.Encode(w, buf)
}
func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) error { func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) error {
var format string var format string
switch formatCode { switch formatCode {
@ -1953,13 +1881,13 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error)
} }
var dataTypeName string var dataTypeName string
if t, ok := m.oidToType[oid]; ok { if t, ok := m.TypeForOID(oid); ok {
dataTypeName = t.Name dataTypeName = t.Name
} else { } else {
dataTypeName = "unknown type" dataTypeName = "unknown type"
} }
return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %s", value, format, dataTypeName, oid, err) return fmt.Errorf("unable to encode %#v into %s format for %s (OID %d): %w", value, format, dataTypeName, oid, err)
} }
// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return

223
vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go generated vendored Normal file
View File

@ -0,0 +1,223 @@
package pgtype
import (
"net"
"net/netip"
"reflect"
"sync"
"time"
)
var (
// defaultMap contains default mappings between PostgreSQL server types and Go type handling logic.
defaultMap *Map
defaultMapInitOnce = sync.Once{}
)
func initDefaultMap() {
defaultMap = &Map{
oidToType: make(map[uint32]*Type),
nameToType: make(map[string]*Type),
reflectTypeToName: make(map[reflect.Type]string),
oidToFormatCode: make(map[uint32]int16),
memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan),
memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan),
TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{
TryWrapDerefPointerEncodePlan,
TryWrapBuiltinTypeEncodePlan,
TryWrapFindUnderlyingTypeEncodePlan,
TryWrapStructEncodePlan,
TryWrapSliceEncodePlan,
TryWrapMultiDimSliceEncodePlan,
TryWrapArrayEncodePlan,
},
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
TryPointerPointerScanPlan,
TryWrapBuiltinTypeScanPlan,
TryFindUnderlyingTypeScanPlan,
TryWrapStructScanPlan,
TryWrapPtrSliceScanPlan,
TryWrapPtrMultiDimSliceScanPlan,
TryWrapPtrArrayScanPlan,
},
}
// Base types
defaultMap.RegisterType(&Type{Name: "aclitem", OID: ACLItemOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "bit", OID: BitOID, Codec: BitsCodec{}})
defaultMap.RegisterType(&Type{Name: "bool", OID: BoolOID, Codec: BoolCodec{}})
defaultMap.RegisterType(&Type{Name: "box", OID: BoxOID, Codec: BoxCodec{}})
defaultMap.RegisterType(&Type{Name: "bpchar", OID: BPCharOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "bytea", OID: ByteaOID, Codec: ByteaCodec{}})
defaultMap.RegisterType(&Type{Name: "char", OID: QCharOID, Codec: QCharCodec{}})
defaultMap.RegisterType(&Type{Name: "cid", OID: CIDOID, Codec: Uint32Codec{}})
defaultMap.RegisterType(&Type{Name: "cidr", OID: CIDROID, Codec: InetCodec{}})
defaultMap.RegisterType(&Type{Name: "circle", OID: CircleOID, Codec: CircleCodec{}})
defaultMap.RegisterType(&Type{Name: "date", OID: DateOID, Codec: DateCodec{}})
defaultMap.RegisterType(&Type{Name: "float4", OID: Float4OID, Codec: Float4Codec{}})
defaultMap.RegisterType(&Type{Name: "float8", OID: Float8OID, Codec: Float8Codec{}})
defaultMap.RegisterType(&Type{Name: "inet", OID: InetOID, Codec: InetCodec{}})
defaultMap.RegisterType(&Type{Name: "int2", OID: Int2OID, Codec: Int2Codec{}})
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}})
defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}})
defaultMap.RegisterType(&Type{Name: "oid", OID: OIDOID, Codec: Uint32Codec{}})
defaultMap.RegisterType(&Type{Name: "path", OID: PathOID, Codec: PathCodec{}})
defaultMap.RegisterType(&Type{Name: "point", OID: PointOID, Codec: PointCodec{}})
defaultMap.RegisterType(&Type{Name: "polygon", OID: PolygonOID, Codec: PolygonCodec{}})
defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}})
defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}})
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}})
defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}})
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
// Range types
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})
defaultMap.RegisterType(&Type{Name: "int4range", OID: Int4rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int4OID]}})
defaultMap.RegisterType(&Type{Name: "int8range", OID: Int8rangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[Int8OID]}})
defaultMap.RegisterType(&Type{Name: "numrange", OID: NumrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[NumericOID]}})
defaultMap.RegisterType(&Type{Name: "tsrange", OID: TsrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestampOID]}})
defaultMap.RegisterType(&Type{Name: "tstzrange", OID: TstzrangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}})
// Multirange types
defaultMap.RegisterType(&Type{Name: "datemultirange", OID: DatemultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[DaterangeOID]}})
defaultMap.RegisterType(&Type{Name: "int4multirange", OID: Int4multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}})
defaultMap.RegisterType(&Type{Name: "int8multirange", OID: Int8multirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}})
defaultMap.RegisterType(&Type{Name: "nummultirange", OID: NummultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[NumrangeOID]}})
defaultMap.RegisterType(&Type{Name: "tsmultirange", OID: TsmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TsrangeOID]}})
defaultMap.RegisterType(&Type{Name: "tstzmultirange", OID: TstzmultirangeOID, Codec: &MultirangeCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}})
// Array types
defaultMap.RegisterType(&Type{Name: "_aclitem", OID: ACLItemArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ACLItemOID]}})
defaultMap.RegisterType(&Type{Name: "_bit", OID: BitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BitOID]}})
defaultMap.RegisterType(&Type{Name: "_bool", OID: BoolArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoolOID]}})
defaultMap.RegisterType(&Type{Name: "_box", OID: BoxArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BoxOID]}})
defaultMap.RegisterType(&Type{Name: "_bpchar", OID: BPCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[BPCharOID]}})
defaultMap.RegisterType(&Type{Name: "_bytea", OID: ByteaArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[ByteaOID]}})
defaultMap.RegisterType(&Type{Name: "_char", OID: QCharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[QCharOID]}})
defaultMap.RegisterType(&Type{Name: "_cid", OID: CIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDOID]}})
defaultMap.RegisterType(&Type{Name: "_cidr", OID: CIDRArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CIDROID]}})
defaultMap.RegisterType(&Type{Name: "_circle", OID: CircleArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[CircleOID]}})
defaultMap.RegisterType(&Type{Name: "_date", OID: DateArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DateOID]}})
defaultMap.RegisterType(&Type{Name: "_daterange", OID: DaterangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[DaterangeOID]}})
defaultMap.RegisterType(&Type{Name: "_float4", OID: Float4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float4OID]}})
defaultMap.RegisterType(&Type{Name: "_float8", OID: Float8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Float8OID]}})
defaultMap.RegisterType(&Type{Name: "_inet", OID: InetArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[InetOID]}})
defaultMap.RegisterType(&Type{Name: "_int2", OID: Int2ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int2OID]}})
defaultMap.RegisterType(&Type{Name: "_int4", OID: Int4ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4OID]}})
defaultMap.RegisterType(&Type{Name: "_int4range", OID: Int4rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int4rangeOID]}})
defaultMap.RegisterType(&Type{Name: "_int8", OID: Int8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8OID]}})
defaultMap.RegisterType(&Type{Name: "_int8range", OID: Int8rangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[Int8rangeOID]}})
defaultMap.RegisterType(&Type{Name: "_interval", OID: IntervalArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[IntervalOID]}})
defaultMap.RegisterType(&Type{Name: "_json", OID: JSONArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONOID]}})
defaultMap.RegisterType(&Type{Name: "_jsonb", OID: JSONBArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONBOID]}})
defaultMap.RegisterType(&Type{Name: "_jsonpath", OID: JSONPathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[JSONPathOID]}})
defaultMap.RegisterType(&Type{Name: "_line", OID: LineArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LineOID]}})
defaultMap.RegisterType(&Type{Name: "_lseg", OID: LsegArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[LsegOID]}})
defaultMap.RegisterType(&Type{Name: "_macaddr", OID: MacaddrArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[MacaddrOID]}})
defaultMap.RegisterType(&Type{Name: "_name", OID: NameArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NameOID]}})
defaultMap.RegisterType(&Type{Name: "_numeric", OID: NumericArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumericOID]}})
defaultMap.RegisterType(&Type{Name: "_numrange", OID: NumrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[NumrangeOID]}})
defaultMap.RegisterType(&Type{Name: "_oid", OID: OIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[OIDOID]}})
defaultMap.RegisterType(&Type{Name: "_path", OID: PathArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PathOID]}})
defaultMap.RegisterType(&Type{Name: "_point", OID: PointArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PointOID]}})
defaultMap.RegisterType(&Type{Name: "_polygon", OID: PolygonArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[PolygonOID]}})
defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}})
defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}})
defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}})
defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}})
defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}})
defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}})
defaultMap.RegisterType(&Type{Name: "_tsrange", OID: TsrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TsrangeOID]}})
defaultMap.RegisterType(&Type{Name: "_tstzrange", OID: TstzrangeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TstzrangeOID]}})
defaultMap.RegisterType(&Type{Name: "_uuid", OID: UUIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[UUIDOID]}})
defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}})
defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}})
defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}})
// Integer types that directly map to a PostgreSQL type
registerDefaultPgTypeVariants[int16](defaultMap, "int2")
registerDefaultPgTypeVariants[int32](defaultMap, "int4")
registerDefaultPgTypeVariants[int64](defaultMap, "int8")
// Integer types that do not have a direct match to a PostgreSQL type
registerDefaultPgTypeVariants[int8](defaultMap, "int8")
registerDefaultPgTypeVariants[int](defaultMap, "int8")
registerDefaultPgTypeVariants[uint8](defaultMap, "int8")
registerDefaultPgTypeVariants[uint16](defaultMap, "int8")
registerDefaultPgTypeVariants[uint32](defaultMap, "int8")
registerDefaultPgTypeVariants[uint64](defaultMap, "numeric")
registerDefaultPgTypeVariants[uint](defaultMap, "numeric")
registerDefaultPgTypeVariants[float32](defaultMap, "float4")
registerDefaultPgTypeVariants[float64](defaultMap, "float8")
registerDefaultPgTypeVariants[bool](defaultMap, "bool")
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
registerDefaultPgTypeVariants[string](defaultMap, "text")
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")
registerDefaultPgTypeVariants[net.IPNet](defaultMap, "cidr")
registerDefaultPgTypeVariants[netip.Addr](defaultMap, "inet")
registerDefaultPgTypeVariants[netip.Prefix](defaultMap, "cidr")
// pgtype provided structs
registerDefaultPgTypeVariants[Bits](defaultMap, "varbit")
registerDefaultPgTypeVariants[Bool](defaultMap, "bool")
registerDefaultPgTypeVariants[Box](defaultMap, "box")
registerDefaultPgTypeVariants[Circle](defaultMap, "circle")
registerDefaultPgTypeVariants[Date](defaultMap, "date")
registerDefaultPgTypeVariants[Range[Date]](defaultMap, "daterange")
registerDefaultPgTypeVariants[Multirange[Range[Date]]](defaultMap, "datemultirange")
registerDefaultPgTypeVariants[Float4](defaultMap, "float4")
registerDefaultPgTypeVariants[Float8](defaultMap, "float8")
registerDefaultPgTypeVariants[Range[Float8]](defaultMap, "numrange") // There is no PostgreSQL builtin float8range so map it to numrange.
registerDefaultPgTypeVariants[Multirange[Range[Float8]]](defaultMap, "nummultirange") // There is no PostgreSQL builtin float8multirange so map it to nummultirange.
registerDefaultPgTypeVariants[Int2](defaultMap, "int2")
registerDefaultPgTypeVariants[Int4](defaultMap, "int4")
registerDefaultPgTypeVariants[Range[Int4]](defaultMap, "int4range")
registerDefaultPgTypeVariants[Multirange[Range[Int4]]](defaultMap, "int4multirange")
registerDefaultPgTypeVariants[Int8](defaultMap, "int8")
registerDefaultPgTypeVariants[Range[Int8]](defaultMap, "int8range")
registerDefaultPgTypeVariants[Multirange[Range[Int8]]](defaultMap, "int8multirange")
registerDefaultPgTypeVariants[Interval](defaultMap, "interval")
registerDefaultPgTypeVariants[Line](defaultMap, "line")
registerDefaultPgTypeVariants[Lseg](defaultMap, "lseg")
registerDefaultPgTypeVariants[Numeric](defaultMap, "numeric")
registerDefaultPgTypeVariants[Range[Numeric]](defaultMap, "numrange")
registerDefaultPgTypeVariants[Multirange[Range[Numeric]]](defaultMap, "nummultirange")
registerDefaultPgTypeVariants[Path](defaultMap, "path")
registerDefaultPgTypeVariants[Point](defaultMap, "point")
registerDefaultPgTypeVariants[Polygon](defaultMap, "polygon")
registerDefaultPgTypeVariants[TID](defaultMap, "tid")
registerDefaultPgTypeVariants[Text](defaultMap, "text")
registerDefaultPgTypeVariants[Time](defaultMap, "time")
registerDefaultPgTypeVariants[Timestamp](defaultMap, "timestamp")
registerDefaultPgTypeVariants[Timestamptz](defaultMap, "timestamptz")
registerDefaultPgTypeVariants[Range[Timestamp]](defaultMap, "tsrange")
registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange")
registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange")
registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange")
registerDefaultPgTypeVariants[UUID](defaultMap, "uuid")
defaultMap.buildReflectTypeToType()
}

View File

@ -40,7 +40,7 @@ func (p Point) PointValue() (Point, error) {
} }
func parsePoint(src []byte) (*Point, error) { func parsePoint(src []byte) (*Point, error) {
if src == nil || bytes.Compare(src, []byte("null")) == 0 { if src == nil || bytes.Equal(src, []byte("null")) {
return &Point{}, nil return &Point{}, nil
} }

View File

@ -22,7 +22,7 @@ type TIDValuer interface {
// //
// When one does // When one does
// //
// select ctid, * from some_table; // select ctid, * from some_table;
// //
// it is the data type of the ctid hidden system column. // it is the data type of the ctid hidden system column.
// //

View File

@ -3,6 +3,7 @@ package pgtype
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -66,6 +67,55 @@ func (ts Timestamp) Value() (driver.Value, error) {
return ts.Time, nil return ts.Time, nil
} }
func (ts Timestamp) MarshalJSON() ([]byte, error) {
if !ts.Valid {
return []byte("null"), nil
}
var s string
switch ts.InfinityModifier {
case Finite:
s = ts.Time.Format(time.RFC3339Nano)
case Infinity:
s = "infinity"
case NegativeInfinity:
s = "-infinity"
}
return json.Marshal(s)
}
func (ts *Timestamp) UnmarshalJSON(b []byte) error {
var s *string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
if s == nil {
*ts = Timestamp{}
return nil
}
switch *s {
case "infinity":
*ts = Timestamp{Valid: true, InfinityModifier: Infinity}
case "-infinity":
*ts = Timestamp{Valid: true, InfinityModifier: -Infinity}
default:
// PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz
tim, err := time.Parse(time.RFC3339Nano, *s)
if err != nil {
return err
}
*ts = Timestamp{Time: tim, Valid: true}
}
return nil
}
type TimestampCodec struct{} type TimestampCodec struct{}
func (TimestampCodec) FormatSupported(format int16) bool { func (TimestampCodec) FormatSupported(format int16) bool {

View File

@ -97,7 +97,7 @@ func (src UUID) MarshalJSON() ([]byte, error) {
} }
func (dst *UUID) UnmarshalJSON(src []byte) error { func (dst *UUID) UnmarshalJSON(src []byte) error {
if bytes.Compare(src, []byte("null")) == 0 { if bytes.Equal(src, []byte("null")) {
*dst = UUID{} *dst = UUID{}
return nil return nil
} }

View File

@ -28,12 +28,16 @@ type Rows interface {
// to call Close after rows is already closed. // to call Close after rows is already closed.
Close() Close()
// Err returns any error that occurred while reading. // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by
// calling Close or by Next returning false). If it is called early it may return nil even if there was an error
// executing the query.
Err() error Err() error
// CommandTag returns the command tag from this query. It is only available after Rows is closed. // CommandTag returns the command tag from this query. It is only available after Rows is closed.
CommandTag() pgconn.CommandTag CommandTag() pgconn.CommandTag
// FieldDescriptions returns the field descriptions of the columns. It may return nil. In particular this can occur
// when there was an error executing the query.
FieldDescriptions() []pgconn.FieldDescription FieldDescriptions() []pgconn.FieldDescription
// Next prepares the next row for reading. It returns true if there is another // Next prepares the next row for reading. It returns true if there is another
@ -227,7 +231,11 @@ func (rows *baseRows) Scan(dest ...any) error {
if len(dest) == 1 { if len(dest) == 1 {
if rc, ok := dest[0].(RowScanner); ok { if rc, ok := dest[0].(RowScanner); ok {
return rc.ScanRow(rows) err := rc.ScanRow(rows)
if err != nil {
rows.fatal(err)
}
return err
} }
} }
@ -298,7 +306,7 @@ func (rows *baseRows) Values() ([]any, error) {
copy(newBuf, buf) copy(newBuf, buf)
values = append(values, newBuf) values = append(values, newBuf)
default: default:
rows.fatal(errors.New("Unknown format code")) rows.fatal(errors.New("unknown format code"))
} }
} }
@ -488,7 +496,8 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
} }
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
// has fields. The row and T fields will by matched by position. // has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then the field will be
// ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) { func RowToStructByPos[T any](row CollectableRow) (T, error) {
var value T var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
@ -496,7 +505,8 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
} }
// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a // RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
// public fields as row has fields. The row and T fields will by matched by position. // public fields as row has fields. The row and T fields will by matched by position. If the "db" struct tag is "-" then
// the field will be ignored.
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
var value T var value T
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
@ -533,13 +543,16 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
for i := 0; i < dstElemType.NumField(); i++ { for i := 0; i < dstElemType.NumField(); i++ {
sf := dstElemType.Field(i) sf := dstElemType.Field(i)
if sf.PkgPath == "" { // Handle anonymous struct embedding, but do not try to handle embedded pointers.
// Handle anonymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
if sf.Anonymous && sf.Type.Kind() == reflect.Struct { scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) } else if sf.PkgPath == "" {
} else { dbTag, _ := sf.Tag.Lookup(structTagKey)
scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) if dbTag == "-" {
// Field is ignored, skip it.
continue
} }
scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
} }
} }
@ -565,8 +578,28 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
return &value, err return &value, err
} }
// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public
// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
return value, err
}
// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or
// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is
// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-"
// then the field will be ignored.
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
var value T
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
return &value, err
}
type namedStructRowScanner struct { type namedStructRowScanner struct {
ptrToStruct any ptrToStruct any
lax bool
} }
func (rs *namedStructRowScanner) ScanRow(rows Rows) error { func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
@ -578,7 +611,6 @@ func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
dstElemValue := dstValue.Elem() dstElemValue := dstValue.Elem()
scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
if err != nil { if err != nil {
return err return err
} }
@ -638,7 +670,13 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s
colName = sf.Name colName = sf.Name
} }
fpos := fieldPosByName(fldDescs, colName) fpos := fieldPosByName(fldDescs, colName)
if fpos == -1 || fpos >= len(scanTargets) { if fpos == -1 {
if rs.lax {
continue
}
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
}
if fpos >= len(scanTargets) && !rs.lax {
return nil, fmt.Errorf("cannot find field %s in returned row", colName) return nil, fmt.Errorf("cannot find field %s in returned row", colName)
} }
scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()

View File

@ -2,58 +2,58 @@
// //
// A database/sql connection can be established through sql.Open. // A database/sql connection can be established through sql.Open.
// //
// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") // db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
// if err != nil { // if err != nil {
// return err // return err
// } // }
// //
// Or from a DSN string. // Or from a DSN string.
// //
// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") // db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
// if err != nil { // if err != nil {
// return err // return err
// } // }
// //
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
// with sql.Open. // with sql.Open.
// //
// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) // connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
// connConfig.Logger = myLogger // connConfig.Logger = myLogger
// connStr := stdlib.RegisterConnConfig(connConfig) // connStr := stdlib.RegisterConnConfig(connConfig)
// db, _ := sql.Open("pgx", connStr) // db, _ := sql.Open("pgx", connStr)
// //
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters. // pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. It does not support named parameters.
// //
// db.QueryRow("select * from users where id=$1", userID) // db.QueryRow("select * from users where id=$1", userID)
// //
// (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows // (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard database/sql.DB connection pool. This allows
// operations that use pgx specific functionality. // operations that use pgx specific functionality.
// //
// // Given db is a *sql.DB // // Given db is a *sql.DB
// conn, err := db.Conn(context.Background()) // conn, err := db.Conn(context.Background())
// if err != nil { // if err != nil {
// // handle error from acquiring connection from DB pool // // handle error from acquiring connection from DB pool
// } // }
// //
// err = conn.Raw(func(driverConn any) error { // err = conn.Raw(func(driverConn any) error {
// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn // conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
// // Do pgx specific stuff with conn // // Do pgx specific stuff with conn
// conn.CopyFrom(...) // conn.CopyFrom(...)
// return nil // return nil
// }) // })
// if err != nil { // if err != nil {
// // handle error that occurred while using *pgx.Conn // // handle error that occurred while using *pgx.Conn
// } // }
// //
// PostgreSQL Specific Data Types // # PostgreSQL Specific Data Types
// //
// The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes // The pgtype package provides support for PostgreSQL specific types. *pgtype.Map.SQLScanner is an adapter that makes
// these types usable as a sql.Scanner. // these types usable as a sql.Scanner.
// //
// m := pgtype.NewMap() // m := pgtype.NewMap()
// var a []int64 // var a []int64
// err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) // err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a))
package stdlib package stdlib
import ( import (
@ -85,7 +85,13 @@ func init() {
pgxDriver = &Driver{ pgxDriver = &Driver{
configs: make(map[string]*pgx.ConnConfig), configs: make(map[string]*pgx.ConnConfig),
} }
sql.Register("pgx", pgxDriver)
// if pgx driver was already registered by different pgx major version then we
// skip registration under the default name.
if !contains(sql.Drivers(), "pgx") {
sql.Register("pgx", pgxDriver)
}
sql.Register("pgx/v5", pgxDriver)
databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
pgtype.BoolOID: 1, pgtype.BoolOID: 1,
@ -104,6 +110,17 @@ func init() {
} }
} }
// TODO replace by slices.Contains when experimental package will be merged to stdlib
// https://pkg.go.dev/golang.org/x/exp/slices#Contains
func contains(list []string, y string) bool {
for _, x := range list {
if x == y {
return true
}
}
return false
}
// OptionOpenDB options for configuring the driver when opening a new db pool. // OptionOpenDB options for configuring the driver when opening a new db pool.
type OptionOpenDB func(*connector) type OptionOpenDB func(*connector)
@ -140,7 +157,7 @@ func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) err
return nil return nil
} }
newFallbacks := append([]*pgconn.FallbackConfig{&pgconn.FallbackConfig{ newFallbacks := append([]*pgconn.FallbackConfig{{
Host: connConfig.Host, Host: connConfig.Host,
Port: connConfig.Port, Port: connConfig.Port,
TLSConfig: connConfig.TLSConfig, TLSConfig: connConfig.TLSConfig,

View File

@ -44,6 +44,10 @@ type TxOptions struct {
IsoLevel TxIsoLevel IsoLevel TxIsoLevel
AccessMode TxAccessMode AccessMode TxAccessMode
DeferrableMode TxDeferrableMode DeferrableMode TxDeferrableMode
// BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax
// such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings.
BeginQuery string
} }
var emptyTxOptions TxOptions var emptyTxOptions TxOptions
@ -53,6 +57,10 @@ func (txOptions TxOptions) beginSQL() string {
return "begin" return "begin"
} }
if txOptions.BeginQuery != "" {
return txOptions.BeginQuery
}
var buf strings.Builder var buf strings.Builder
buf.Grow(64) // 64 - maximum length of string with available options buf.Grow(64) // 64 - maximum length of string with available options
buf.WriteString("begin") buf.WriteString("begin")
@ -144,7 +152,6 @@ type Tx interface {
// called on the dbTx. // called on the dbTx.
type dbTx struct { type dbTx struct {
conn *Conn conn *Conn
err error
savepointNum int64 savepointNum int64
closed bool closed bool
} }

1
vendor/gorm.io/driver/postgres/.gitignore generated vendored Normal file
View File

@ -0,0 +1 @@
.idea

48
vendor/gorm.io/driver/postgres/error_translator.go generated vendored Normal file
View File

@ -0,0 +1,48 @@
package postgres
import (
"encoding/json"
"gorm.io/gorm"
"github.com/jackc/pgx/v5/pgconn"
)
var errCodes = map[string]error{
"23505": gorm.ErrDuplicatedKey,
"23503": gorm.ErrForeignKeyViolated,
"42703": gorm.ErrInvalidField,
}
type ErrMessage struct {
Code string
Severity string
Message string
}
// Translate it will translate the error to native gorm errors.
// Since currently gorm supporting both pgx and pg drivers, only checking for pgx PgError types is not enough for translating errors, so we have additional error json marshal fallback.
func (dialector Dialector) Translate(err error) error {
if pgErr, ok := err.(*pgconn.PgError); ok {
if translatedErr, found := errCodes[pgErr.Code]; found {
return translatedErr
}
return err
}
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}
var errMsg ErrMessage
unmarshalErr := json.Unmarshal(parsedErr, &errMsg)
if unmarshalErr != nil {
return err
}
if translatedErr, found := errCodes[errMsg.Code]; found {
return translatedErr
}
return err
}

View File

@ -13,44 +13,61 @@ import (
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
) )
// See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql
// Here are some changes:
// - use `LEFT JOIN` instead of `CROSS JOIN`
// - exclude indexes used to support constraints (they are auto-generated)
const indexSql = ` const indexSql = `
select SELECT
t.relname as table_name, ct.relname AS table_name,
i.relname as index_name, ci.relname AS index_name,
a.attname as column_name, i.indisunique AS non_unique,
ix.indisunique as non_unique, i.indisprimary AS primary,
ix.indisprimary as primary a.attname AS column_name
from FROM
pg_class t, pg_index i
pg_class i, LEFT JOIN pg_class ct ON ct.oid = i.indrelid
pg_index ix, LEFT JOIN pg_class ci ON ci.oid = i.indexrelid
pg_attribute a LEFT JOIN pg_attribute a ON a.attrelid = ct.oid
where LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid
t.oid = ix.indrelid WHERE
and i.oid = ix.indexrelid a.attnum = ANY(i.indkey)
and a.attrelid = t.oid AND con.oid IS NULL
and a.attnum = ANY(ix.indkey) AND ct.relkind = 'r'
and t.relkind = 'r' AND ct.relname = ?
and t.relname = ?
` `
var typeAliasMap = map[string][]string{ var typeAliasMap = map[string][]string{
"int2": {"smallint"}, "int2": {"smallint"},
"int4": {"integer"}, "int4": {"integer"},
"int8": {"bigint"}, "int8": {"bigint"},
"smallint": {"int2"}, "smallint": {"int2"},
"integer": {"int4"}, "integer": {"int4"},
"bigint": {"int8"}, "bigint": {"int8"},
"decimal": {"numeric"}, "decimal": {"numeric"},
"numeric": {"decimal"}, "numeric": {"decimal"},
"timestamptz": {"timestamp with time zone"},
"timestamp with time zone": {"timestamptz"},
"bool": {"boolean"},
"boolean": {"bool"},
} }
type Migrator struct { type Migrator struct {
migrator.Migrator migrator.Migrator
} }
// select querys ignore dryrun
func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) {
queryTx := m.DB
if m.DB.DryRun {
queryTx = m.DB.Session(&gorm.Session{})
queryTx.DryRun = false
}
return queryTx.Raw(sql, values...)
}
func (m Migrator) CurrentDatabase() (name string) { func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name)
return return
} }
@ -76,11 +93,13 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
func (m Migrator) HasIndex(value interface{}, name string) bool { func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw( return m.queryRaw(
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
).Scan(&count).Error ).Scan(&count).Error
}) })
@ -90,33 +109,35 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt) if idx := stmt.Schema.LookIndex(name); idx != nil {
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
createIndexSQL := "CREATE " createIndexSQL := "CREATE "
if idx.Class != "" { if idx.Class != "" {
createIndexSQL += idx.Class + " " createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX "
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
createIndexSQL += "CONCURRENTLY "
}
createIndexSQL += "IF NOT EXISTS ? ON ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type + "(?)"
} else {
createIndexSQL += " ?"
}
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
} }
createIndexSQL += "INDEX "
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
createIndexSQL += "CONCURRENTLY "
}
createIndexSQL += "IF NOT EXISTS ? ON ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type + "(?)"
} else {
createIndexSQL += " ?"
}
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
} }
return fmt.Errorf("failed to create index with name %v", name) return fmt.Errorf("failed to create index with name %v", name)
@ -134,8 +155,10 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
@ -144,7 +167,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
func (m Migrator) GetTables() (tableList []string, err error) { func (m Migrator) GetTables() (tableList []string, err error) {
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
} }
func (m Migrator) CreateTable(values ...interface{}) (err error) { func (m Migrator) CreateTable(values ...interface{}) (err error) {
@ -153,13 +176,16 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
} }
for _, value := range m.ReorderModels(values, false) { for _, value := range m.ReorderModels(values, false) {
if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, field := range stmt.Schema.FieldsByDBName { if stmt.Schema != nil {
if field.Comment != "" { for _, fieldName := range stmt.Schema.DBNames {
if err := m.DB.Exec( field := stmt.Schema.FieldsByDBName[fieldName]
"COMMENT ON COLUMN ?.? IS ?", if field.Comment != "" {
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), if err := m.DB.Exec(
).Error; err != nil { "COMMENT ON COLUMN ?.? IS ?",
return err m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
}
} }
} }
} }
@ -175,7 +201,7 @@ func (m Migrator) HasTable(value interface{}) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
}) })
return count > 0 return count > 0
} }
@ -200,13 +226,15 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
m.resetPreparedStmts() m.resetPreparedStmts()
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
if field.Comment != "" { if field := stmt.Schema.LookUpField(field); field != nil {
if err := m.DB.Exec( if field.Comment != "" {
"COMMENT ON COLUMN ?.? IS ?", if err := m.DB.Exec(
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), "COMMENT ON COLUMN ?.? IS ?",
).Error; err != nil { m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
return err ).Error; err != nil {
return err
}
} }
} }
} }
@ -225,7 +253,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
} }
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.DB.Raw( return m.queryRaw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentSchema, curTable, name, currentSchema, curTable, name,
).Scan(&count).Error ).Scan(&count).Error
@ -250,7 +278,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
m.DB.Raw(checkSQL, values...).Scan(&description) m.queryRaw(checkSQL, values...).Scan(&description)
comment := strings.Trim(field.Comment, "'") comment := strings.Trim(field.Comment, "'")
comment = strings.Trim(comment, `"`) comment = strings.Trim(comment, `"`)
@ -269,101 +297,92 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
// AlterColumn alter value's `field` column' type based on schema definition // AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
err := m.RunWithValue(value, func(stmt *gorm.Statement) error { err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
var ( if field := stmt.Schema.LookUpField(field); field != nil {
columnTypes, _ = m.DB.Migrator().ColumnTypes(value) var (
fieldColumnType *migrator.ColumnType columnTypes, _ = m.DB.Migrator().ColumnTypes(value)
) fieldColumnType *migrator.ColumnType
for _, columnType := range columnTypes { )
if columnType.Name() == field.DBName { for _, columnType := range columnTypes {
fieldColumnType, _ = columnType.(*migrator.ColumnType) if columnType.Name() == field.DBName {
} fieldColumnType, _ = columnType.(*migrator.ColumnType)
}
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
// check for typeName and SQL name
isSameType := true
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
isSameType = false
// if different, also check for aliases
aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
for _, alias := range aliases {
if strings.HasPrefix(fileType.SQL, alias) {
isSameType = true
break
} }
} }
}
// not same, migrate fileType := clause.Expr{SQL: m.DataTypeOf(field)}
if !isSameType { // check for typeName and SQL name
filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() isSameType := true
if field.AutoIncrement && filedColumnAutoIncrement { // update if fieldColumnType.DatabaseTypeName() != fileType.SQL {
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) isSameType = false
if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { // if different, also check for aliases
if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
return err for _, alias := range aliases {
if strings.HasPrefix(fileType.SQL, alias) {
isSameType = true
break
} }
} }
} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err
}
} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}
} }
}
if null, _ := fieldColumnType.Nullable(); null == field.NotNull { // not same, migrate
if field.NotNull { if !isSameType {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
return err if field.AutoIncrement && filedColumnAutoIncrement { // update
} serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
} else { if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err return err
} }
} }
} } else if field.AutoIncrement && !filedColumnAutoIncrement { // create
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if uniq, _ := fieldColumnType.Unique(); !uniq && field.Unique { if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)}
// Not a unique constraint but a unique index
if !m.HasIndex(stmt.Table, idxName.Name) {
if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
}
if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
return err return err
} }
} else if field.DefaultValue != "(-)" { } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
return err return err
} }
} else { } else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil {
return err return err
} }
} }
} }
if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
if field.NotNull {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
}
if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
return err
}
} else if field.DefaultValue != "(-)" {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
}
}
}
return nil
} }
return nil
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -375,18 +394,39 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
return nil return nil
} }
func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error {
alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?"
isUncastableDefaultValue := false
if targetType.SQL == "boolean" {
switch existingColumn.DatabaseTypeName() {
case "int2", "int8", "numeric":
alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?"
}
isUncastableDefaultValue = true
}
if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil {
return err
}
return nil
}
func (m Migrator) HasConstraint(value interface{}, name string) bool { func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name) constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
currentSchema, curTable := m.CurrentSchema(stmt, table)
if constraint != nil { if constraint != nil {
name = constraint.Name name = constraint.GetName()
} else if chk != nil {
name = chk.Name
} }
currentSchema, curTable := m.CurrentSchema(stmt, table)
return m.DB.Raw( return m.queryRaw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
currentSchema, curTable, name, currentSchema, curTable, name,
).Scan(&count).Error ).Scan(&count).Error
@ -401,7 +441,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
var ( var (
currentDatabase = m.DB.Migrator().CurrentDatabase() currentDatabase = m.DB.Migrator().CurrentDatabase()
currentSchema, table = m.CurrentSchema(stmt, stmt.Table) currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
columns, err = m.DB.Raw( columns, err = m.queryRaw(
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows() currentDatabase, currentSchema, table).Rows()
) )
@ -441,7 +481,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
} }
if column.DefaultValueValue.Valid { if column.DefaultValueValue.Valid {
column.DefaultValueValue.String = regexp.MustCompile(`'?(.*)\b'?:+[\w\s]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String)
} }
if datetimePrecision.Valid { if datetimePrecision.Valid {
@ -475,7 +515,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
// check primary, unique field // check primary, unique field
{ {
columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
if err != nil { if err != nil {
return err return err
} }
@ -487,7 +527,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
} }
columnTypeRows.Close() columnTypeRows.Close()
columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil { if err != nil {
return err return err
} }
@ -514,7 +554,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
// check column type // check column type
{ {
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns AND NOT a.attisdropped -- hide deleted columns
@ -672,7 +712,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
err := m.RunWithValue(value, func(stmt *gorm.Statement) error { err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
result := make([]*Index, 0) result := make([]*Index, 0)
scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error
if scanErr != nil { if scanErr != nil {
return scanErr return scanErr
} }
@ -747,3 +787,8 @@ func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error {
m.resetPreparedStmts() m.resetPreparedStmts()
return nil return nil
} }
func parseDefaultValueValue(defaultValue string) string {
value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1")
return strings.Trim(value, "'")
}

View File

@ -24,11 +24,17 @@ type Dialector struct {
type Config struct { type Config struct {
DriverName string DriverName string
DSN string DSN string
WithoutQuotingCheck bool
PreferSimpleProtocol bool PreferSimpleProtocol bool
WithoutReturning bool WithoutReturning bool
Conn gorm.ConnPool Conn gorm.ConnPool
} }
var (
timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")
defaultIdentifierLength = 63 //maximum identifier length for postgres
)
func Open(dsn string) gorm.Dialector { func Open(dsn string) gorm.Dialector {
return &Dialector{&Config{DSN: dsn}} return &Dialector{&Config{DSN: dsn}}
} }
@ -41,17 +47,42 @@ func (dialector Dialector) Name() string {
return "postgres" return "postgres"
} }
var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{
IdentifierMaxLength: defaultIdentifierLength,
}
return nil
}
switch v := config.NamingStrategy.(type) {
case *schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
}
case schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
config.NamingStrategy = v
}
}
return nil
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) { func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
callbackConfig := &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
}
// register callbacks // register callbacks
if !dialector.WithoutReturning { if !dialector.WithoutReturning {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
})
} }
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
if dialector.Conn != nil { if dialector.Conn != nil {
db.ConnPool = dialector.Conn db.ConnPool = dialector.Conn
@ -90,10 +121,23 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('$') writer.WriteByte('$')
writer.WriteString(strconv.Itoa(len(stmt.Vars))) index := 0
varLen := len(stmt.Vars)
if varLen > 0 {
switch stmt.Vars[0].(type) {
case pgx.QueryExecMode:
index++
}
}
writer.WriteString(strconv.Itoa(varLen - index))
} }
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
if dialector.WithoutQuotingCheck {
writer.WriteString(str)
return
}
var ( var (
underQuoted, selfQuoted bool underQuoted, selfQuoted bool
continuousBacktick int8 continuousBacktick int8

12
vendor/modules.txt vendored
View File

@ -83,8 +83,6 @@ github.com/go-playground/universal-translator
# github.com/go-playground/validator/v10 v10.25.0 # github.com/go-playground/validator/v10 v10.25.0
## explicit; go 1.20 ## explicit; go 1.20
github.com/go-playground/validator/v10 github.com/go-playground/validator/v10
# github.com/golang-jwt/jwt v3.2.2+incompatible
## explicit
# github.com/golang-jwt/jwt/v5 v5.2.1 # github.com/golang-jwt/jwt/v5 v5.2.1
## explicit; go 1.18 ## explicit; go 1.18
github.com/golang-jwt/jwt/v5 github.com/golang-jwt/jwt/v5
@ -109,16 +107,16 @@ github.com/jackc/pgpassfile
# github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 # github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
## explicit; go 1.14 ## explicit; go 1.14
github.com/jackc/pgservicefile github.com/jackc/pgservicefile
# github.com/jackc/pgx/v5 v5.2.0 # github.com/jackc/pgx/v5 v5.4.3
## explicit; go 1.18 ## explicit; go 1.19
github.com/jackc/pgx/v5 github.com/jackc/pgx/v5
github.com/jackc/pgx/v5/internal/anynil github.com/jackc/pgx/v5/internal/anynil
github.com/jackc/pgx/v5/internal/iobufpool github.com/jackc/pgx/v5/internal/iobufpool
github.com/jackc/pgx/v5/internal/nbconn
github.com/jackc/pgx/v5/internal/pgio github.com/jackc/pgx/v5/internal/pgio
github.com/jackc/pgx/v5/internal/sanitize github.com/jackc/pgx/v5/internal/sanitize
github.com/jackc/pgx/v5/internal/stmtcache github.com/jackc/pgx/v5/internal/stmtcache
github.com/jackc/pgx/v5/pgconn github.com/jackc/pgx/v5/pgconn
github.com/jackc/pgx/v5/pgconn/internal/bgreader
github.com/jackc/pgx/v5/pgconn/internal/ctxwatch github.com/jackc/pgx/v5/pgconn/internal/ctxwatch
github.com/jackc/pgx/v5/pgproto3 github.com/jackc/pgx/v5/pgproto3
github.com/jackc/pgx/v5/pgtype github.com/jackc/pgx/v5/pgtype
@ -284,8 +282,8 @@ gopkg.in/ini.v1
# gopkg.in/yaml.v3 v3.0.1 # gopkg.in/yaml.v3 v3.0.1
## explicit ## explicit
gopkg.in/yaml.v3 gopkg.in/yaml.v3
# gorm.io/driver/postgres v1.4.7 # gorm.io/driver/postgres v1.5.7
## explicit; go 1.14 ## explicit; go 1.18
gorm.io/driver/postgres gorm.io/driver/postgres
# gorm.io/gorm v1.25.12 # gorm.io/gorm v1.25.12
## explicit; go 1.18 ## explicit; go 1.18