VYPR
High severity · NVD Advisory · Published Mar 6, 2024 · Updated Dec 12, 2024

pgx SQL Injection via Protocol Message Size Overflow

CVE-2024-27304

Description

pgx is a PostgreSQL driver and toolkit for Go. SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. Because the PostgreSQL wire protocol encodes each message's length in a 32-bit field, the calculated message size can overflow. The overflow can cause that single large message to be sent as multiple messages whose contents are under the attacker's control. The problem is resolved in v4.18.2 and v5.5.4. As a workaround, reject user input large enough to cause a single query or bind message to exceed 4 GB in size.

Affected packages

Versions sourced from the GitHub Security Advisory.

Package | Affected versions | Patched versions
github.com/jackc/pgx (Go) | < 4.18.2 | 4.18.2
github.com/jackc/pgx (Go) | >= 5.0.0, < 5.5.4 | 5.5.4
github.com/jackc/pgx/v4 (Go) | < 4.18.2 | 4.18.2
github.com/jackc/pgx/v5 (Go) | >= 5.0.0, < 5.5.4 | 5.5.4

Affected products: 1

Patches: 4
c543134753a0

SQL sanitizer wraps arguments in parentheses

https://github.com/jackc/pgx · Jack Christensen · Mar 4, 2024 · via ghsa
2 files changed · +23 -9
  • internal/sanitize/sanitize.go (+4 -0, modified)
    @@ -63,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) {
     				return "", fmt.Errorf("invalid arg type: %T", arg)
     			}
     			argUse[argIdx] = true
    +
    +			// Prevent SQL injection via Line Comment Creation
    +			// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
    +			str = "(" + str + ")"
     		default:
     			return "", fmt.Errorf("invalid Part type: %T", part)
     		}
    
  • internal/sanitize/sanitize_test.go (+19 -9, modified)
    @@ -132,47 +132,57 @@ func TestQuerySanitize(t *testing.T) {
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{int64(42)},
    -			expected: `select 42`,
    +			expected: `select (42)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{float64(1.23)},
    -			expected: `select 1.23`,
    +			expected: `select (1.23)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{true},
    -			expected: `select true`,
    +			expected: `select (true)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{[]byte{0, 1, 2, 3, 255}},
    -			expected: `select '\x00010203ff'`,
    +			expected: `select ('\x00010203ff')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{nil},
    -			expected: `select null`,
    +			expected: `select (null)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{"foobar"},
    -			expected: `select 'foobar'`,
    +			expected: `select ('foobar')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{"foo'bar"},
    -			expected: `select 'foo''bar'`,
    +			expected: `select ('foo''bar')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []any{`foo\'bar`},
    -			expected: `select 'foo\''bar'`,
    +			expected: `select ('foo\''bar')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
     			args:     []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
    -			expected: `insert '2020-03-01 23:59:59.999999Z'`,
    +			expected: `insert ('2020-03-01 23:59:59.999999Z')`,
    +		},
    +		{
    +			query:    sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
    +			args:     []any{int64(-1)},
    +			expected: `select 1-(-1)`,
    +		},
    +		{
    +			query:    sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
    +			args:     []any{float64(-1)},
    +			expected: `select 1-(-1)`,
     		},
     	}
     
    
945c2126f6db

Backport fixes from pgx v5

https://github.com/jackc/pgproto3 · Jack Christensen · Mar 2, 2024 · via ghsa
59 files changed · +359 -359
  • authentication_cleartext_password.go (+3 -4, modified)
    @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 8)
    +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • authentication_gss_continue.go (+4 -4, modified)
    @@ -4,6 +4,7 @@ import (
     	"encoding/binary"
     	"encoding/json"
     	"errors"
    +
     	"github.com/jackc/pgio"
     )
     
    @@ -30,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
     	return nil
     }
     
    -func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
    +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
     	dst = append(dst, a.Data...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
    
  • authentication_gss.go (+4 -4, modified)
    @@ -4,6 +4,7 @@ import (
     	"encoding/binary"
     	"encoding/json"
     	"errors"
    +
     	"github.com/jackc/pgio"
     )
     
    @@ -26,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
     	return nil
     }
     
    -func (a *AuthenticationGSS) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 4)
    +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeGSS)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
    
  • authentication_md5_password.go (+3 -4, modified)
    @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 12)
    +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
     	dst = append(dst, src.Salt[:]...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • authentication_ok.go (+3 -4, modified)
    @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationOk) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 8)
    +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeOk)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • authentication_sasl_continue.go (+3 -9, modified)
    @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
    -
     	dst = append(dst, src.Data...)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • authentication_sasl_final.go (+3 -9, modified)
    @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
    -
     	dst = append(dst, src.Data...)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Unmarshaler.
    
  • authentication_sasl.go (+3 -7, modified)
    @@ -46,10 +46,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationSASL) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeSASL)
     
     	for _, s := range src.AuthMechanisms {
    @@ -58,9 +56,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
     	}
     	dst = append(dst, 0)
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • backend.go (+10 -5, modified)
    @@ -49,7 +49,12 @@ func NewBackend(cr ChunkReader, w io.Writer) *Backend {
     
     // Send sends a message to the frontend.
     func (b *Backend) Send(msg BackendMessage) error {
    -	_, err := b.w.Write(msg.Encode(nil))
    +	buf, err := msg.Encode(nil)
    +	if err != nil {
    +		return err
    +	}
    +
    +	_, err = b.w.Write(buf)
     	return err
     }
     
    @@ -184,11 +189,11 @@ func (b *Backend) Receive() (FrontendMessage, error) {
     // contextual identification of FrontendMessages. For example, in the
     // PG message flow documentation for PasswordMessage:
     //
    -// 		Byte1('p')
    +//			Byte1('p')
     //
    -//      Identifies the message as a password response. Note that this is also used for
    -//		GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
    -//		the context.
    +//	     Identifies the message as a password response. Note that this is also used for
    +//			GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
    +//			the context.
     //
     // Since the Frontend does not know about the state of a backend, it is important
     // to call SetAuthType() after an authentication request is received by the Frontend.
    
  • backend_key_data.go (+3 -4, modified)
    @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *BackendKeyData) Encode(dst []byte) []byte {
    -	dst = append(dst, 'K')
    -	dst = pgio.AppendUint32(dst, 12)
    +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'K')
     	dst = pgio.AppendUint32(dst, src.ProcessID)
     	dst = pgio.AppendUint32(dst, src.SecretKey)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • backend_test.go (+2 -2, modified)
    @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
     				"username": "tester",
     			},
     		}
    -		dst := []byte{}
    -		dst = want.Encode(dst)
    +		dst, err := want.Encode([]byte{})
    +		require.NoError(t, err)
     
     		server := &interruptReader{}
     		server.push(dst)
    
  • bind_complete.go (+2 -2, modified)
    @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *BindComplete) Encode(dst []byte) []byte {
    -	return append(dst, '2', 0, 0, 0, 4)
    +func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, '2', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • bind.go (+14 -7, modified)
    @@ -5,7 +5,9 @@ import (
     	"encoding/binary"
     	"encoding/hex"
     	"encoding/json"
    +	"errors"
     	"fmt"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Bind) Encode(dst []byte) []byte {
    -	dst = append(dst, 'B')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *Bind) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'B')
     
     	dst = append(dst, src.DestinationPortal...)
     	dst = append(dst, 0)
     	dst = append(dst, src.PreparedStatement...)
     	dst = append(dst, 0)
     
    +	if len(src.ParameterFormatCodes) > math.MaxUint16 {
    +		return nil, errors.New("too many parameter format codes")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
     	for _, fc := range src.ParameterFormatCodes {
     		dst = pgio.AppendInt16(dst, fc)
     	}
     
    +	if len(src.Parameters) > math.MaxUint16 {
    +		return nil, errors.New("too many parameters")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
     	for _, p := range src.Parameters {
     		if p == nil {
    @@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
     		dst = append(dst, p...)
     	}
     
    +	if len(src.ResultFormatCodes) > math.MaxUint16 {
    +		return nil, errors.New("too many result format codes")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
     	for _, fc := range src.ResultFormatCodes {
     		dst = pgio.AppendInt16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • bind_test.go (+20 -0, added)
    @@ -0,0 +1,20 @@
    +package pgproto3_test
    +
    +import (
    +	"testing"
    +
    +	"github.com/jackc/pgproto3/v2"
    +	"github.com/stretchr/testify/require"
    +)
    +
    +func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
    +	t.Parallel()
    +
    +	// Maximum allowed size.
    +	_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
    +	require.NoError(t, err)
    +
    +	// 1 byte too big
    +	_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
    +	require.Error(t, err)
    +}
    
  • cancel_request.go (+2 -2, modified)
    @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 4 byte message length.
    -func (src *CancelRequest) Encode(dst []byte) []byte {
    +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
     	dst = pgio.AppendInt32(dst, 16)
     	dst = pgio.AppendInt32(dst, cancelRequestCode)
     	dst = pgio.AppendUint32(dst, src.ProcessID)
     	dst = pgio.AppendUint32(dst, src.SecretKey)
    -	return dst
    +	return dst, nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • close_complete.go (+2 -2, modified)
    @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CloseComplete) Encode(dst []byte) []byte {
    -	return append(dst, '3', 0, 0, 0, 4)
    +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, '3', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • close.go (+3 -11, modified)
    @@ -4,8 +4,6 @@ import (
     	"bytes"
     	"encoding/json"
     	"errors"
    -
    -	"github.com/jackc/pgio"
     )
     
     type Close struct {
    @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Close) Encode(dst []byte) []byte {
    -	dst = append(dst, 'C')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *Close) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'C')
     	dst = append(dst, src.ObjectType)
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • command_complete.go (+3 -11, modified)
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type CommandComplete struct {
    @@ -28,17 +26,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CommandComplete) Encode(dst []byte) []byte {
    -	dst = append(dst, 'C')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'C')
     	dst = append(dst, src.CommandTag...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • copy_both_response.go (+7 -7, modified)
    @@ -5,6 +5,7 @@ import (
     	"encoding/binary"
     	"encoding/json"
     	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyBothResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'W')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'W')
     	dst = append(dst, src.OverallFormat)
    +	if len(src.ColumnFormatCodes) > math.MaxUint16 {
    +		return nil, errors.New("too many column format codes")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
     	for _, fc := range src.ColumnFormatCodes {
     		dst = pgio.AppendUint16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • copy_both_response_test.go (+3 -1, modified)
    @@ -5,6 +5,7 @@ import (
     
     	"github.com/jackc/pgproto3/v2"
     	"github.com/stretchr/testify/assert"
    +	"github.com/stretchr/testify/require"
     )
     
     func TestEncodeDecode(t *testing.T) {
    @@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
     	err := dstResp.Decode(srcBytes[5:])
     	assert.NoError(t, err, "No errors on decode")
     	dstBytes := []byte{}
    -	dstBytes = dstResp.Encode(dstBytes)
    +	dstBytes, err = dstResp.Encode(dstBytes)
    +	require.NoError(t, err)
     	assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
     }
    
  • copy_data.go (+3 -6, modified)
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"encoding/hex"
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type CopyData struct {
    @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyData) Encode(dst []byte) []byte {
    -	dst = append(dst, 'd')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
    +func (src *CopyData) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'd')
     	dst = append(dst, src.Data...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • copy_done.go (+2 -2, modified)
    @@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyDone) Encode(dst []byte) []byte {
    -	return append(dst, 'c', 0, 0, 0, 4)
    +func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'c', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • copy_fail.go (+3 -11, modified)
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type CopyFail struct {
    @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyFail) Encode(dst []byte) []byte {
    -	dst = append(dst, 'f')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'f')
     	dst = append(dst, src.Message...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • copy_in_response.go (+7 -7, modified)
    @@ -5,6 +5,7 @@ import (
     	"encoding/binary"
     	"encoding/json"
     	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyInResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'G')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'G')
     
     	dst = append(dst, src.OverallFormat)
    +	if len(src.ColumnFormatCodes) > math.MaxUint16 {
    +		return nil, errors.New("too many column format codes")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
     	for _, fc := range src.ColumnFormatCodes {
     		dst = pgio.AppendUint16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • copy_out_response.go (+7 -7, modified)
    @@ -5,6 +5,7 @@ import (
     	"encoding/binary"
     	"encoding/json"
     	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyOutResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'H')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'H')
     
     	dst = append(dst, src.OverallFormat)
     
    +	if len(src.ColumnFormatCodes) > math.MaxUint16 {
    +		return nil, errors.New("too many column format codes")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
     	for _, fc := range src.ColumnFormatCodes {
     		dst = pgio.AppendUint16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • data_row.go (+8 -7, modified)
    @@ -4,6 +4,8 @@ import (
     	"encoding/binary"
     	"encoding/hex"
     	"encoding/json"
    +	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *DataRow) Encode(dst []byte) []byte {
    -	dst = append(dst, 'D')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *DataRow) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'D')
     
    +	if len(src.Values) > math.MaxUint16 {
    +		return nil, errors.New("too many values")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
     	for _, v := range src.Values {
     		if v == nil {
    @@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
     		dst = append(dst, v...)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • describe.go (+3 -11, modified)
    @@ -4,8 +4,6 @@ import (
     	"bytes"
     	"encoding/json"
     	"errors"
    -
    -	"github.com/jackc/pgio"
     )
     
     type Describe struct {
    @@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Describe) Encode(dst []byte) []byte {
    -	dst = append(dst, 'D')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *Describe) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'D')
     	dst = append(dst, src.ObjectType)
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • empty_query_response.go (+2 -2, modified)
    @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
    -	return append(dst, 'I', 0, 0, 0, 4)
    +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'I', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • error_response.go (+64 -72, modified)
    @@ -2,7 +2,6 @@ package pgproto3
     
     import (
     	"bytes"
    -	"encoding/binary"
     	"encoding/json"
     	"strconv"
     )
    @@ -111,120 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ErrorResponse) Encode(dst []byte) []byte {
    -	return append(dst, src.marshalBinary('E')...)
    +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'E')
    +	dst = src.appendFields(dst)
    +	return finishMessage(dst, sp)
     }
     
    -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
    -	var bigEndian BigEndianBuf
    -	buf := &bytes.Buffer{}
    -
    -	buf.WriteByte(typeByte)
    -	buf.Write(bigEndian.Uint32(0))
    -
    +func (src *ErrorResponse) appendFields(dst []byte) []byte {
     	if src.Severity != "" {
    -		buf.WriteByte('S')
    -		buf.WriteString(src.Severity)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'S')
    +		dst = append(dst, src.Severity...)
    +		dst = append(dst, 0)
     	}
     	if src.SeverityUnlocalized != "" {
    -		buf.WriteByte('V')
    -		buf.WriteString(src.SeverityUnlocalized)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'V')
    +		dst = append(dst, src.SeverityUnlocalized...)
    +		dst = append(dst, 0)
     	}
     	if src.Code != "" {
    -		buf.WriteByte('C')
    -		buf.WriteString(src.Code)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'C')
    +		dst = append(dst, src.Code...)
    +		dst = append(dst, 0)
     	}
     	if src.Message != "" {
    -		buf.WriteByte('M')
    -		buf.WriteString(src.Message)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'M')
    +		dst = append(dst, src.Message...)
    +		dst = append(dst, 0)
     	}
     	if src.Detail != "" {
    -		buf.WriteByte('D')
    -		buf.WriteString(src.Detail)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'D')
    +		dst = append(dst, src.Detail...)
    +		dst = append(dst, 0)
     	}
     	if src.Hint != "" {
    -		buf.WriteByte('H')
    -		buf.WriteString(src.Hint)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'H')
    +		dst = append(dst, src.Hint...)
    +		dst = append(dst, 0)
     	}
     	if src.Position != 0 {
    -		buf.WriteByte('P')
    -		buf.WriteString(strconv.Itoa(int(src.Position)))
    -		buf.WriteByte(0)
    +		dst = append(dst, 'P')
    +		dst = append(dst, strconv.Itoa(int(src.Position))...)
    +		dst = append(dst, 0)
     	}
     	if src.InternalPosition != 0 {
    -		buf.WriteByte('p')
    -		buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
    -		buf.WriteByte(0)
    +		dst = append(dst, 'p')
    +		dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
    +		dst = append(dst, 0)
     	}
     	if src.InternalQuery != "" {
    -		buf.WriteByte('q')
    -		buf.WriteString(src.InternalQuery)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'q')
    +		dst = append(dst, src.InternalQuery...)
    +		dst = append(dst, 0)
     	}
     	if src.Where != "" {
    -		buf.WriteByte('W')
    -		buf.WriteString(src.Where)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'W')
    +		dst = append(dst, src.Where...)
    +		dst = append(dst, 0)
     	}
     	if src.SchemaName != "" {
    -		buf.WriteByte('s')
    -		buf.WriteString(src.SchemaName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 's')
    +		dst = append(dst, src.SchemaName...)
    +		dst = append(dst, 0)
     	}
     	if src.TableName != "" {
    -		buf.WriteByte('t')
    -		buf.WriteString(src.TableName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 't')
    +		dst = append(dst, src.TableName...)
    +		dst = append(dst, 0)
     	}
     	if src.ColumnName != "" {
    -		buf.WriteByte('c')
    -		buf.WriteString(src.ColumnName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'c')
    +		dst = append(dst, src.ColumnName...)
    +		dst = append(dst, 0)
     	}
     	if src.DataTypeName != "" {
    -		buf.WriteByte('d')
    -		buf.WriteString(src.DataTypeName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'd')
    +		dst = append(dst, src.DataTypeName...)
    +		dst = append(dst, 0)
     	}
     	if src.ConstraintName != "" {
    -		buf.WriteByte('n')
    -		buf.WriteString(src.ConstraintName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'n')
    +		dst = append(dst, src.ConstraintName...)
    +		dst = append(dst, 0)
     	}
     	if src.File != "" {
    -		buf.WriteByte('F')
    -		buf.WriteString(src.File)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'F')
    +		dst = append(dst, src.File...)
    +		dst = append(dst, 0)
     	}
     	if src.Line != 0 {
    -		buf.WriteByte('L')
    -		buf.WriteString(strconv.Itoa(int(src.Line)))
    -		buf.WriteByte(0)
    +		dst = append(dst, 'L')
    +		dst = append(dst, strconv.Itoa(int(src.Line))...)
    +		dst = append(dst, 0)
     	}
     	if src.Routine != "" {
    -		buf.WriteByte('R')
    -		buf.WriteString(src.Routine)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'R')
    +		dst = append(dst, src.Routine...)
    +		dst = append(dst, 0)
     	}
     
     	for k, v := range src.UnknownFields {
    -		buf.WriteByte(k)
    -		buf.WriteByte(0)
    -		buf.WriteString(v)
    -		buf.WriteByte(0)
    +		dst = append(dst, k)
    +		dst = append(dst, v...)
    +		dst = append(dst, 0)
     	}
     
    -	buf.WriteByte(0)
    -
    -	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
    +	dst = append(dst, 0)
     
    -	return buf.Bytes()
    +	return dst
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • example/pgfortune/server.go+14 7 modified
    @@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
     				return fmt.Errorf("error generating query response: %w", err)
     			}
     
    -			buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
    +			buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
     				{
     					Name:                 []byte("fortune"),
     					TableOID:             0,
    @@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
     					TypeModifier:         -1,
     					Format:               0,
     				},
    -			}}).Encode(nil)
    -			buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
    -			buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
    -			buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
    +			}}).Encode(nil))
    +			buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
    +			buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
    +			buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
     			_, err = p.conn.Write(buf)
     			if err != nil {
     				return fmt.Errorf("error writing query response: %w", err)
    @@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
     
     	switch startupMessage.(type) {
     	case *pgproto3.StartupMessage:
    -		buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
    -		buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
    +		buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
    +		buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
     		_, err = p.conn.Write(buf)
     		if err != nil {
     			return fmt.Errorf("error sending ready for query: %w", err)
    @@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
     func (p *PgFortuneBackend) Close() error {
     	return p.conn.Close()
     }
    +
    +func mustEncode(buf []byte, err error) []byte {
    +	if err != nil {
    +		panic(err)
    +	}
    +	return buf
    +}
    
  • execute.go+3 10 modified
    @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Execute) Encode(dst []byte) []byte {
    -	dst = append(dst, 'E')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *Execute) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'E')
     	dst = append(dst, src.Portal...)
     	dst = append(dst, 0)
    -
     	dst = pgio.AppendUint32(dst, src.MaxRows)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • flush.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Flush) Encode(dst []byte) []byte {
    -	return append(dst, 'H', 0, 0, 0, 4)
    +func (src *Flush) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'H', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • frontend.go+5 1 modified
    @@ -57,7 +57,11 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
     
     // Send sends a message to the backend.
     func (f *Frontend) Send(msg FrontendMessage) error {
    -	_, err := f.w.Write(msg.Encode(nil))
    +	buf, err := msg.Encode(nil)
    +	if err != nil {
    +		return err
    +	}
    +	_, err = f.w.Write(buf)
     	return err
     }
     
    
  • function_call.go+14 6 modified
    @@ -2,6 +2,9 @@ package pgproto3
     
     import (
     	"encoding/binary"
    +	"errors"
    +	"math"
    +
     	"github.com/jackc/pgio"
     )
     
    @@ -70,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *FunctionCall) Encode(dst []byte) []byte {
    -	dst = append(dst, 'F')
    -	sp := len(dst)
    -	dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
    +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'F')
     	dst = pgio.AppendUint32(dst, src.Function)
    +
    +	if len(src.ArgFormatCodes) > math.MaxUint16 {
    +		return nil, errors.New("too many arg format codes")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
     	for _, argFormatCode := range src.ArgFormatCodes {
     		dst = pgio.AppendUint16(dst, argFormatCode)
     	}
    +
    +	if len(src.Arguments) > math.MaxUint16 {
    +		return nil, errors.New("too many arguments")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
     	for _, argument := range src.Arguments {
     		if argument == nil {
    @@ -89,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
     		}
     	}
     	dst = pgio.AppendUint16(dst, src.ResultFormatCode)
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -	return dst
    +	return finishMessage(dst, sp)
     }
    
  • function_call_response.go+3 7 modified
    @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *FunctionCallResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'V')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'V')
     
     	if src.Result == nil {
     		dst = pgio.AppendInt32(dst, -1)
    @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
     		dst = append(dst, src.Result...)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • function_call_test.go+5 2 modified
    @@ -4,6 +4,8 @@ import (
     	"encoding/binary"
     	"reflect"
     	"testing"
    +
    +	"github.com/stretchr/testify/require"
     )
     
     func TestFunctionCall_EncodeDecode(t *testing.T) {
    @@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
     				Arguments:        tt.fields.Arguments,
     				ResultFormatCode: tt.fields.ResultFormatCode,
     			}
    -			encoded := src.Encode([]byte{})
    +			encoded, err := src.Encode([]byte{})
    +			require.NoError(t, err)
     			dst := &FunctionCall{}
     			// Check the header
     			msgTypeCode := encoded[0]
    @@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
     				t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
     			}
     			// Check decoding works as expected
    -			err := dst.Decode(encoded[5:])
    +			err = dst.Decode(encoded[5:])
     			if err != nil {
     				if !tt.wantErr {
     					t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
    
  • gss_enc_request.go+2 2 modified
    @@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 4 byte message length.
    -func (src *GSSEncRequest) Encode(dst []byte) []byte {
    +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
     	dst = pgio.AppendInt32(dst, 8)
     	dst = pgio.AppendInt32(dst, gssEncReqNumber)
    -	return dst
    +	return dst, nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • gss_response.go+3 5 modified
    @@ -2,7 +2,6 @@ package pgproto3
     
     import (
     	"encoding/json"
    -	"github.com/jackc/pgio"
     )
     
     type GSSResponse struct {
    @@ -17,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
     	return nil
     }
     
    -func (g *GSSResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
    +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     	dst = append(dst, g.Data...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • no_data.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *NoData) Encode(dst []byte) []byte {
    -	return append(dst, 'n', 0, 0, 0, 4)
    +func (src *NoData) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'n', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • notice_response.go+4 2 modified
    @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *NoticeResponse) Encode(dst []byte) []byte {
    -	return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
    +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'N')
    +	dst = (*ErrorResponse)(src).appendFields(dst)
    +	return finishMessage(dst, sp)
     }
    
  • notification_response.go+3 9 modified
    @@ -41,20 +41,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *NotificationResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'A')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'A')
     	dst = pgio.AppendUint32(dst, src.PID)
     	dst = append(dst, src.Channel...)
     	dst = append(dst, 0)
     	dst = append(dst, src.Payload...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • parameter_description.go+8 7 modified
    @@ -4,6 +4,8 @@ import (
     	"bytes"
     	"encoding/binary"
     	"encoding/json"
    +	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ParameterDescription) Encode(dst []byte) []byte {
    -	dst = append(dst, 't')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 't')
     
    +	if len(src.ParameterOIDs) > math.MaxUint16 {
    +		return nil, errors.New("too many parameter oids")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
     	for _, oid := range src.ParameterOIDs {
     		dst = pgio.AppendUint32(dst, oid)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • parameter_status.go+3 11 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type ParameterStatus struct {
    @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ParameterStatus) Encode(dst []byte) []byte {
    -	dst = append(dst, 'S')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'S')
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
     	dst = append(dst, src.Value...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • parse_complete.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ParseComplete) Encode(dst []byte) []byte {
    -	return append(dst, '1', 0, 0, 0, 4)
    +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, '1', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • parse.go+8 7 modified
    @@ -4,6 +4,8 @@ import (
     	"bytes"
     	"encoding/binary"
     	"encoding/json"
    +	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Parse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'P')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *Parse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'P')
     
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
     	dst = append(dst, src.Query...)
     	dst = append(dst, 0)
     
    +	if len(src.ParameterOIDs) > math.MaxUint16 {
    +		return nil, errors.New("too many parameter oids")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
     	for _, oid := range src.ParameterOIDs {
     		dst = pgio.AppendUint32(dst, oid)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • password_message.go+3 8 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type PasswordMessage struct {
    @@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *PasswordMessage) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
    -
    +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     	dst = append(dst, src.Password...)
     	dst = append(dst, 0)
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3.go+27 1 modified
    @@ -4,8 +4,14 @@ import (
     	"encoding/hex"
     	"errors"
     	"fmt"
    +
    +	"github.com/jackc/pgio"
     )
     
    +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
    +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
    +const maxMessageBodyLen = (0x3fffffff - 1)
    +
     // Message is the interface implemented by an object that can decode and encode
     // a particular PostgreSQL message.
     type Message interface {
    @@ -14,7 +20,7 @@ type Message interface {
     	Decode(data []byte) error
     
     	// Encode appends itself to dst and returns the new buffer.
    -	Encode(dst []byte) []byte
    +	Encode(dst []byte) ([]byte, error)
     }
     
     type FrontendMessage interface {
    @@ -63,3 +69,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
     	}
     	return nil, errors.New("unknown protocol representation")
     }
    +
    +// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
    +// dst. It returns the new buffer and the position of the message length placeholder.
    +func beginMessage(dst []byte, t byte) ([]byte, int) {
    +	dst = append(dst, t)
    +	sp := len(dst)
    +	dst = pgio.AppendInt32(dst, -1)
    +	return dst, sp
    +}
    +
    +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
    +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
    +func finishMessage(dst []byte, sp int) ([]byte, error) {
    +	messageBodyLen := len(dst[sp:])
    +	if messageBodyLen > maxMessageBodyLen {
    +		return nil, errors.New("message body too large")
    +	}
    +	pgio.SetInt32(dst[sp:], int32(messageBodyLen))
    +	return dst, nil
    +}
    
  • pgproto3_private_test.go+3 0 added
    @@ -0,0 +1,3 @@
    +package pgproto3
    +
    +const MaxMessageBodyLen = maxMessageBodyLen
    
  • portal_suspended.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *PortalSuspended) Encode(dst []byte) []byte {
    -	return append(dst, 's', 0, 0, 0, 4)
    +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 's', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • query.go+3 8 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type Query struct {
    @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Query) Encode(dst []byte) []byte {
    -	dst = append(dst, 'Q')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
    -
    +func (src *Query) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'Q')
     	dst = append(dst, src.String...)
     	dst = append(dst, 0)
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • query_test.go+20 0 added
    @@ -0,0 +1,20 @@
    +package pgproto3_test
    +
    +import (
    +	"testing"
    +
    +	"github.com/jackc/pgproto3/v2"
    +	"github.com/stretchr/testify/require"
    +)
    +
    +func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) {
    +	t.Parallel()
    +
    +	// Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string.
    +	_, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil)
    +	require.NoError(t, err)
    +
    +	// 1 byte too big
    +	_, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil)
    +	require.Error(t, err)
    +}
    
  • ready_for_query.go+2 2 modified
    @@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ReadyForQuery) Encode(dst []byte) []byte {
    -	return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
    +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • row_description.go+8 7 modified
    @@ -4,6 +4,8 @@ import (
     	"bytes"
     	"encoding/binary"
     	"encoding/json"
    +	"errors"
    +	"math"
     
     	"github.com/jackc/pgio"
     )
    @@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *RowDescription) Encode(dst []byte) []byte {
    -	dst = append(dst, 'T')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'T')
     
    +	if len(src.Fields) > math.MaxUint16 {
    +		return nil, errors.New("too many fields")
    +	}
     	dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
     	for _, fd := range src.Fields {
     		dst = append(dst, fd.Name...)
    @@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
     		dst = pgio.AppendInt16(dst, fd.Format)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • sasl_initial_response.go+3 7 modified
    @@ -38,20 +38,16 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *SASLInitialResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     
     	dst = append(dst, []byte(src.AuthMechanism)...)
     	dst = append(dst, 0)
     
     	dst = pgio.AppendInt32(dst, int32(len(src.Data)))
     	dst = append(dst, src.Data...)
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • sasl_response.go+3 8 modified
    @@ -2,8 +2,6 @@ package pgproto3
     
     import (
     	"encoding/json"
    -
    -	"github.com/jackc/pgio"
     )
     
     type SASLResponse struct {
    @@ -21,13 +19,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *SASLResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
    -
    +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     	dst = append(dst, src.Data...)
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • ssl_request.go+2 2 modified
    @@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 4 byte message length.
    -func (src *SSLRequest) Encode(dst []byte) []byte {
    +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
     	dst = pgio.AppendInt32(dst, 8)
     	dst = pgio.AppendInt32(dst, sslRequestNumber)
    -	return dst
    +	return dst, nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • startup_message.go+2 4 modified
    @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *StartupMessage) Encode(dst []byte) []byte {
    +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
     	sp := len(dst)
     	dst = pgio.AppendInt32(dst, -1)
     
    @@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
     	}
     	dst = append(dst, 0)
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • sync.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Sync) Encode(dst []byte) []byte {
    -	return append(dst, 'S', 0, 0, 0, 4)
    +func (src *Sync) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'S', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • terminate.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Terminate) Encode(dst []byte) []byte {
    -	return append(dst, 'X', 0, 0, 0, 4)
    +func (src *Terminate) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'X', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
adbb38f298c7

Do not allow protocol messages larger than ~1GB

https://github.com/jackc/pgx · Jack Christensen · Mar 2, 2024 · via ghsa
61 files changed · +472 -390
  • pgconn/pgconn.go+41 5 modified
    @@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
     // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
     type Batch struct {
     	buf []byte
    +	err error
     }
     
     // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
     func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
    -	batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
    +	if batch.err != nil {
    +		return
    +	}
    +
    +	batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
    +	if batch.err != nil {
    +		return
    +	}
     	batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
     }
     
     // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
     func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
    -	batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
    -	batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
    -	batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
    +	if batch.err != nil {
    +		return
    +	}
    +
    +	batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
    +	if batch.err != nil {
    +		return
    +	}
    +
    +	batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
    +	if batch.err != nil {
    +		return
    +	}
    +
    +	batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
    +	if batch.err != nil {
    +		return
    +	}
     }
     
     // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
     // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
     // multiple queries in a single round trip than using pipeline mode.
     func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
    +	if batch.err != nil {
    +		return &MultiResultReader{
    +			closed: true,
    +			err:    batch.err,
    +		}
    +	}
    +
     	if err := pgConn.lock(); err != nil {
     		return &MultiResultReader{
     			closed: true,
    @@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
     		pgConn.contextWatcher.Watch(ctx)
     	}
     
    -	batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
    +	batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
    +	if batch.err != nil {
    +		multiResult.closed = true
    +		multiResult.err = batch.err
    +		pgConn.unlock()
    +		return multiResult
    +	}
     
     	pgConn.enterPotentialWriteReadDeadlock()
     	defer pgConn.exitPotentialWriteReadDeadlock()
    
  • pgconn/pgconn_test.go+10 3 modified
    @@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) {
     					return
     				}
     
    -				srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
    -				srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
    -				srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
    +				srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
    +				srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
    +				srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
     
     				serverSNINameChan <- sniHost
     			}()
    @@ -3472,3 +3472,10 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
     	err = pipeline.Close()
     	require.Error(t, err)
     }
    +
    +func mustEncode(buf []byte, err error) []byte {
    +	if err != nil {
    +		panic(err)
    +	}
    +	return buf
    +}
    
  • pgproto3/authentication_cleartext_password.go+3 4 modified
    @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 8)
    +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/authentication_gss_continue.go+3 4 modified
    @@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
     	return nil
     }
     
    -func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
    +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
     	dst = append(dst, a.Data...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
    
  • pgproto3/authentication_gss.go+3 4 modified
    @@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
     	return nil
     }
     
    -func (a *AuthenticationGSS) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 4)
    +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeGSS)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
    
  • pgproto3/authentication_md5_password.go+3 4 modified
    @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 12)
    +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
     	dst = append(dst, src.Salt[:]...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/authentication_ok.go+3 4 modified
    @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationOk) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	dst = pgio.AppendInt32(dst, 8)
    +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeOk)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/authentication_sasl_continue.go+3 9 modified
    @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
    -
     	dst = append(dst, src.Data...)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/authentication_sasl_final.go+3 9 modified
    @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
    -
     	dst = append(dst, src.Data...)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Unmarshaler.
    
  • pgproto3/authentication_sasl.go+3 7 modified
    @@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *AuthenticationSASL) Encode(dst []byte) []byte {
    -	dst = append(dst, 'R')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'R')
     	dst = pgio.AppendUint32(dst, AuthTypeSASL)
     
     	for _, s := range src.AuthMechanisms {
    @@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
     	}
     	dst = append(dst, 0)
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/backend.go+21 4 modified
    @@ -16,7 +16,8 @@ type Backend struct {
     	// before it is actually transmitted (i.e. before Flush).
     	tracer *tracer
     
    -	wbuf []byte
    +	wbuf        []byte
    +	encodeError error
     
     	// Frontend message flyweights
     	bind           Bind
    @@ -55,18 +56,34 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
     	return &Backend{cr: cr, w: w}
     }
     
    -// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
    -// called.
    +// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
    +// encountered will be returned from Flush.
     func (b *Backend) Send(msg BackendMessage) {
    +	if b.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(b.wbuf)
    -	b.wbuf = msg.Encode(b.wbuf)
    +	newBuf, err := msg.Encode(b.wbuf)
    +	if err != nil {
    +		b.encodeError = err
    +		return
    +	}
    +	b.wbuf = newBuf
    +
     	if b.tracer != nil {
     		b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
     	}
     }
     
     // Flush writes any pending messages to the frontend (i.e. the client).
     func (b *Backend) Flush() error {
    +	if err := b.encodeError; err != nil {
    +		b.encodeError = nil
    +		b.wbuf = b.wbuf[:0]
    +		return &writeError{err: err, safeToRetry: true}
    +	}
    +
     	n, err := b.w.Write(b.wbuf)
     
     	const maxLen = 1024
    
  • pgproto3/backend_key_data.go+3 4 modified
    @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *BackendKeyData) Encode(dst []byte) []byte {
    -	dst = append(dst, 'K')
    -	dst = pgio.AppendUint32(dst, 12)
    +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'K')
     	dst = pgio.AppendUint32(dst, src.ProcessID)
     	dst = pgio.AppendUint32(dst, src.SecretKey)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/backend_test.go+2 2 modified
    @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
     				"username": "tester",
     			},
     		}
    -		dst := []byte{}
    -		dst = want.Encode(dst)
    +		dst, err := want.Encode([]byte{})
    +		require.NoError(t, err)
     
     		server := &interruptReader{}
     		server.push(dst)
    
  • pgproto3/bind_complete.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *BindComplete) Encode(dst []byte) []byte {
    -	return append(dst, '2', 0, 0, 0, 4)
    +func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, '2', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/bind.go+3 7 modified
    @@ -108,10 +108,8 @@ func (dst *Bind) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Bind) Encode(dst []byte) []byte {
    -	dst = append(dst, 'B')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *Bind) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'B')
     
     	dst = append(dst, src.DestinationPortal...)
     	dst = append(dst, 0)
    @@ -139,9 +137,7 @@ func (src *Bind) Encode(dst []byte) []byte {
     		dst = pgio.AppendInt16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/bind_test.go+20 0 added
    @@ -0,0 +1,20 @@
    +package pgproto3_test
    +
    +import (
    +	"testing"
    +
    +	"github.com/jackc/pgx/v5/pgproto3"
    +	"github.com/stretchr/testify/require"
    +)
    +
    +func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
    +	t.Parallel()
    +
    +	// Maximum allowed size.
    +	_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
    +	require.NoError(t, err)
    +
    +	// 1 byte too big
    +	_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
    +	require.Error(t, err)
    +}
    
  • pgproto3/cancel_request.go+2 2 modified
    @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 4 byte message length.
    -func (src *CancelRequest) Encode(dst []byte) []byte {
    +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
     	dst = pgio.AppendInt32(dst, 16)
     	dst = pgio.AppendInt32(dst, cancelRequestCode)
     	dst = pgio.AppendUint32(dst, src.ProcessID)
     	dst = pgio.AppendUint32(dst, src.SecretKey)
    -	return dst
    +	return dst, nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/close_complete.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CloseComplete) Encode(dst []byte) []byte {
    -	return append(dst, '3', 0, 0, 0, 4)
    +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, '3', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/close.go+3 11 modified
    @@ -4,8 +4,6 @@ import (
     	"bytes"
     	"encoding/json"
     	"errors"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type Close struct {
    @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Close) Encode(dst []byte) []byte {
    -	dst = append(dst, 'C')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *Close) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'C')
     	dst = append(dst, src.ObjectType)
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/command_complete.go+3 11 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type CommandComplete struct {
    @@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CommandComplete) Encode(dst []byte) []byte {
    -	dst = append(dst, 'C')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'C')
     	dst = append(dst, src.CommandTag...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/copy_both_response.go+3 7 modified
    @@ -44,19 +44,15 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyBothResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'W')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'W')
     	dst = append(dst, src.OverallFormat)
     	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
     	for _, fc := range src.ColumnFormatCodes {
     		dst = pgio.AppendUint16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/copy_both_response_test.go+3 1 modified
    @@ -5,6 +5,7 @@ import (
     
     	"github.com/jackc/pgx/v5/pgproto3"
     	"github.com/stretchr/testify/assert"
    +	"github.com/stretchr/testify/require"
     )
     
     func TestEncodeDecode(t *testing.T) {
    @@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
     	err := dstResp.Decode(srcBytes[5:])
     	assert.NoError(t, err, "No errors on decode")
     	dstBytes := []byte{}
    -	dstBytes = dstResp.Encode(dstBytes)
    +	dstBytes, err = dstResp.Encode(dstBytes)
    +	require.NoError(t, err)
     	assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
     }
    
  • pgproto3/copy_data.go+3 6 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"encoding/hex"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type CopyData struct {
    @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyData) Encode(dst []byte) []byte {
    -	dst = append(dst, 'd')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
    +func (src *CopyData) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'd')
     	dst = append(dst, src.Data...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/copy_done.go+2 2 modified
    @@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyDone) Encode(dst []byte) []byte {
    -	return append(dst, 'c', 0, 0, 0, 4)
    +func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'c', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/copy_fail.go+3 11 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type CopyFail struct {
    @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyFail) Encode(dst []byte) []byte {
    -	dst = append(dst, 'f')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'f')
     	dst = append(dst, src.Message...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/copy_in_response.go+3 7 modified
    @@ -44,20 +44,16 @@ func (dst *CopyInResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyInResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'G')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'G')
     
     	dst = append(dst, src.OverallFormat)
     	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
     	for _, fc := range src.ColumnFormatCodes {
     		dst = pgio.AppendUint16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/copy_out_response.go+3 7 modified
    @@ -43,10 +43,8 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *CopyOutResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'H')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'H')
     
     	dst = append(dst, src.OverallFormat)
     
    @@ -55,9 +53,7 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte {
     		dst = pgio.AppendUint16(dst, fc)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/data_row.go+3 7 modified
    @@ -63,10 +63,8 @@ func (dst *DataRow) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *DataRow) Encode(dst []byte) []byte {
    -	dst = append(dst, 'D')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *DataRow) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'D')
     
     	dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
     	for _, v := range src.Values {
    @@ -79,9 +77,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
     		dst = append(dst, v...)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/describe.go+3 11 modified
    @@ -4,8 +4,6 @@ import (
     	"bytes"
     	"encoding/json"
     	"errors"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type Describe struct {
    @@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Describe) Encode(dst []byte) []byte {
    -	dst = append(dst, 'D')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *Describe) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'D')
     	dst = append(dst, src.ObjectType)
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/empty_query_response.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
    -	return append(dst, 'I', 0, 0, 0, 4)
    +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'I', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/error_response.go+64 71 modified
    @@ -2,7 +2,6 @@ package pgproto3
     
     import (
     	"bytes"
    -	"encoding/binary"
     	"encoding/json"
     	"strconv"
     )
    @@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ErrorResponse) Encode(dst []byte) []byte {
    -	return append(dst, src.marshalBinary('E')...)
    +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'E')
    +	dst = src.appendFields(dst)
    +	return finishMessage(dst, sp)
     }
     
    -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
    -	var bigEndian BigEndianBuf
    -	buf := &bytes.Buffer{}
    -
    -	buf.WriteByte(typeByte)
    -	buf.Write(bigEndian.Uint32(0))
    -
    +func (src *ErrorResponse) appendFields(dst []byte) []byte {
     	if src.Severity != "" {
    -		buf.WriteByte('S')
    -		buf.WriteString(src.Severity)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'S')
    +		dst = append(dst, src.Severity...)
    +		dst = append(dst, 0)
     	}
     	if src.SeverityUnlocalized != "" {
    -		buf.WriteByte('V')
    -		buf.WriteString(src.SeverityUnlocalized)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'V')
    +		dst = append(dst, src.SeverityUnlocalized...)
    +		dst = append(dst, 0)
     	}
     	if src.Code != "" {
    -		buf.WriteByte('C')
    -		buf.WriteString(src.Code)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'C')
    +		dst = append(dst, src.Code...)
    +		dst = append(dst, 0)
     	}
     	if src.Message != "" {
    -		buf.WriteByte('M')
    -		buf.WriteString(src.Message)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'M')
    +		dst = append(dst, src.Message...)
    +		dst = append(dst, 0)
     	}
     	if src.Detail != "" {
    -		buf.WriteByte('D')
    -		buf.WriteString(src.Detail)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'D')
    +		dst = append(dst, src.Detail...)
    +		dst = append(dst, 0)
     	}
     	if src.Hint != "" {
    -		buf.WriteByte('H')
    -		buf.WriteString(src.Hint)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'H')
    +		dst = append(dst, src.Hint...)
    +		dst = append(dst, 0)
     	}
     	if src.Position != 0 {
    -		buf.WriteByte('P')
    -		buf.WriteString(strconv.Itoa(int(src.Position)))
    -		buf.WriteByte(0)
    +		dst = append(dst, 'P')
    +		dst = append(dst, strconv.Itoa(int(src.Position))...)
    +		dst = append(dst, 0)
     	}
     	if src.InternalPosition != 0 {
    -		buf.WriteByte('p')
    -		buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
    -		buf.WriteByte(0)
    +		dst = append(dst, 'p')
    +		dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
    +		dst = append(dst, 0)
     	}
     	if src.InternalQuery != "" {
    -		buf.WriteByte('q')
    -		buf.WriteString(src.InternalQuery)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'q')
    +		dst = append(dst, src.InternalQuery...)
    +		dst = append(dst, 0)
     	}
     	if src.Where != "" {
    -		buf.WriteByte('W')
    -		buf.WriteString(src.Where)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'W')
    +		dst = append(dst, src.Where...)
    +		dst = append(dst, 0)
     	}
     	if src.SchemaName != "" {
    -		buf.WriteByte('s')
    -		buf.WriteString(src.SchemaName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 's')
    +		dst = append(dst, src.SchemaName...)
    +		dst = append(dst, 0)
     	}
     	if src.TableName != "" {
    -		buf.WriteByte('t')
    -		buf.WriteString(src.TableName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 't')
    +		dst = append(dst, src.TableName...)
    +		dst = append(dst, 0)
     	}
     	if src.ColumnName != "" {
    -		buf.WriteByte('c')
    -		buf.WriteString(src.ColumnName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'c')
    +		dst = append(dst, src.ColumnName...)
    +		dst = append(dst, 0)
     	}
     	if src.DataTypeName != "" {
    -		buf.WriteByte('d')
    -		buf.WriteString(src.DataTypeName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'd')
    +		dst = append(dst, src.DataTypeName...)
    +		dst = append(dst, 0)
     	}
     	if src.ConstraintName != "" {
    -		buf.WriteByte('n')
    -		buf.WriteString(src.ConstraintName)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'n')
    +		dst = append(dst, src.ConstraintName...)
    +		dst = append(dst, 0)
     	}
     	if src.File != "" {
    -		buf.WriteByte('F')
    -		buf.WriteString(src.File)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'F')
    +		dst = append(dst, src.File...)
    +		dst = append(dst, 0)
     	}
     	if src.Line != 0 {
    -		buf.WriteByte('L')
    -		buf.WriteString(strconv.Itoa(int(src.Line)))
    -		buf.WriteByte(0)
    +		dst = append(dst, 'L')
    +		dst = append(dst, strconv.Itoa(int(src.Line))...)
    +		dst = append(dst, 0)
     	}
     	if src.Routine != "" {
    -		buf.WriteByte('R')
    -		buf.WriteString(src.Routine)
    -		buf.WriteByte(0)
    +		dst = append(dst, 'R')
    +		dst = append(dst, src.Routine...)
    +		dst = append(dst, 0)
     	}
     
     	for k, v := range src.UnknownFields {
    -		buf.WriteByte(k)
    -		buf.WriteString(v)
    -		buf.WriteByte(0)
    +		dst = append(dst, k)
    +		dst = append(dst, v...)
    +		dst = append(dst, 0)
     	}
     
    -	buf.WriteByte(0)
    -
    -	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
    +	dst = append(dst, 0)
     
    -	return buf.Bytes()
    +	return dst
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/example/pgfortune/server.go+14 7 modified
    @@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
     				return fmt.Errorf("error generating query response: %w", err)
     			}
     
    -			buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
    +			buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
     				{
     					Name:                 []byte("fortune"),
     					TableOID:             0,
    @@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
     					TypeModifier:         -1,
     					Format:               0,
     				},
    -			}}).Encode(nil)
    -			buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
    -			buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
    -			buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
    +			}}).Encode(nil))
    +			buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
    +			buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
    +			buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
     			_, err = p.conn.Write(buf)
     			if err != nil {
     				return fmt.Errorf("error writing query response: %w", err)
    @@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
     
     	switch startupMessage.(type) {
     	case *pgproto3.StartupMessage:
    -		buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
    -		buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
    +		buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
    +		buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
     		_, err = p.conn.Write(buf)
     		if err != nil {
     			return fmt.Errorf("error sending ready for query: %w", err)
    @@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
     func (p *PgFortuneBackend) Close() error {
     	return p.conn.Close()
     }
    +
    +func mustEncode(buf []byte, err error) []byte {
    +	if err != nil {
    +		panic(err)
    +	}
    +	return buf
    +}
    
  • pgproto3/execute.go+3 10 modified
    @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Execute) Encode(dst []byte) []byte {
    -	dst = append(dst, 'E')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *Execute) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'E')
     	dst = append(dst, src.Portal...)
     	dst = append(dst, 0)
    -
     	dst = pgio.AppendUint32(dst, src.MaxRows)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/flush.go+2 2 modified
    @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Flush) Encode(dst []byte) []byte {
    -	return append(dst, 'H', 0, 0, 0, 4)
    +func (src *Flush) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'H', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/frontend.go+112 25 modified
    @@ -18,7 +18,8 @@ type Frontend struct {
     	// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
     	tracer *tracer
     
    -	wbuf []byte
    +	wbuf        []byte
    +	encodeError error
     
     	// Backend message flyweights
     	authenticationOk                AuthenticationOk
    @@ -64,23 +65,39 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
     	return &Frontend{cr: cr, w: w}
     }
     
    -// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
    -// called.
    +// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
    +// encountered will be returned from Flush.
     //
     // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
     // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
     // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
     // behind an interface.
     func (f *Frontend) Send(msg FrontendMessage) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
     // Flush writes any pending messages to the backend (i.e. the server).
     func (f *Frontend) Flush() error {
    +	if err := f.encodeError; err != nil {
    +		f.encodeError = nil
    +		f.wbuf = f.wbuf[:0]
    +		return &writeError{err: err, safeToRetry: true}
    +	}
    +
     	if len(f.wbuf) == 0 {
     		return nil
     	}
    @@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
     	f.tracer = nil
     }
     
    -// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
    +// error encountered will be returned from Flush.
     func (f *Frontend) SendBind(msg *Bind) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
    -// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
    +// error encountered will be returned from Flush.
     func (f *Frontend) SendParse(msg *Parse) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
    -// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
    +// error encountered will be returned from Flush.
     func (f *Frontend) SendClose(msg *Close) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
    -// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
    +// called. Any error encountered will be returned from Flush.
     func (f *Frontend) SendDescribe(msg *Describe) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
    -// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
    +// Any error encountered will be returned from Flush.
     func (f *Frontend) SendExecute(msg *Execute) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
    -// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
    +// error encountered will be returned from Flush.
     func (f *Frontend) SendSync(msg *Sync) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
     	}
     }
     
    -// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
    -// Flush is called.
    +// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
    +// error encountered will be returned from Flush.
     func (f *Frontend) SendQuery(msg *Query) {
    +	if f.encodeError != nil {
    +		return
    +	}
    +
     	prevLen := len(f.wbuf)
    -	f.wbuf = msg.Encode(f.wbuf)
    +	newBuf, err := msg.Encode(f.wbuf)
    +	if err != nil {
    +		f.encodeError = err
    +		return
    +	}
    +	f.wbuf = newBuf
    +
     	if f.tracer != nil {
     		f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
     	}
    
  • pgproto3/function_call.go +3 -6 modified
    @@ -71,10 +71,8 @@ func (dst *FunctionCall) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *FunctionCall) Encode(dst []byte) []byte {
    -	dst = append(dst, 'F')
    -	sp := len(dst)
    -	dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
    +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'F')
     	dst = pgio.AppendUint32(dst, src.Function)
     	dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
     	for _, argFormatCode := range src.ArgFormatCodes {
    @@ -90,6 +88,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
     		}
     	}
     	dst = pgio.AppendUint16(dst, src.ResultFormatCode)
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -	return dst
    +	return finishMessage(dst, sp)
     }
    
  • pgproto3/function_call_response.go +3 -7 modified
    @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *FunctionCallResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'V')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'V')
     
     	if src.Result == nil {
     		dst = pgio.AppendInt32(dst, -1)
    @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
     		dst = append(dst, src.Result...)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/function_call_test.go +5 -2 modified
    @@ -4,6 +4,8 @@ import (
     	"encoding/binary"
     	"reflect"
     	"testing"
    +
    +	"github.com/stretchr/testify/require"
     )
     
     func TestFunctionCall_EncodeDecode(t *testing.T) {
    @@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
     				Arguments:        tt.fields.Arguments,
     				ResultFormatCode: tt.fields.ResultFormatCode,
     			}
    -			encoded := src.Encode([]byte{})
    +			encoded, err := src.Encode([]byte{})
    +			require.NoError(t, err)
     			dst := &FunctionCall{}
     			// Check the header
     			msgTypeCode := encoded[0]
    @@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) {
     				t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded))
     			}
     			// Check decoding works as expected
    -			err := dst.Decode(encoded[5:])
    +			err = dst.Decode(encoded[5:])
     			if err != nil {
     				if !tt.wantErr {
     					t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr)
    
  • pgproto3/gss_enc_request.go +2 -2 modified
    @@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 4 byte message length.
    -func (src *GSSEncRequest) Encode(dst []byte) []byte {
    +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
     	dst = pgio.AppendInt32(dst, 8)
     	dst = pgio.AppendInt32(dst, gssEncReqNumber)
    -	return dst
    +	return dst, nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/gss_response.go +3 -6 modified
    @@ -2,8 +2,6 @@ package pgproto3
     
     import (
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type GSSResponse struct {
    @@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
     	return nil
     }
     
    -func (g *GSSResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
    +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     	dst = append(dst, g.Data...)
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/no_data.go +2 -2 modified
    @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *NoData) Encode(dst []byte) []byte {
    -	return append(dst, 'n', 0, 0, 0, 4)
    +func (src *NoData) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'n', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/notice_response.go +4 -2 modified
    @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *NoticeResponse) Encode(dst []byte) []byte {
    -	return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
    +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'N')
    +	dst = (*ErrorResponse)(src).appendFields(dst)
    +	return finishMessage(dst, sp)
     }
    
  • pgproto3/notification_response.go +3 -9 modified
    @@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *NotificationResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'A')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'A')
     	dst = pgio.AppendUint32(dst, src.PID)
     	dst = append(dst, src.Channel...)
     	dst = append(dst, 0)
     	dst = append(dst, src.Payload...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/parameter_description.go +3 -7 modified
    @@ -39,19 +39,15 @@ func (dst *ParameterDescription) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ParameterDescription) Encode(dst []byte) []byte {
    -	dst = append(dst, 't')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 't')
     
     	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
     	for _, oid := range src.ParameterOIDs {
     		dst = pgio.AppendUint32(dst, oid)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/parameter_status.go +3 -11 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type ParameterStatus struct {
    @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ParameterStatus) Encode(dst []byte) []byte {
    -	dst = append(dst, 'S')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    -
    +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'S')
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
     	dst = append(dst, src.Value...)
     	dst = append(dst, 0)
    -
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/parse_complete.go +2 -2 modified
    @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ParseComplete) Encode(dst []byte) []byte {
    -	return append(dst, '1', 0, 0, 0, 4)
    +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, '1', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/parse.go +3 -7 modified
    @@ -52,10 +52,8 @@ func (dst *Parse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Parse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'P')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *Parse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'P')
     
     	dst = append(dst, src.Name...)
     	dst = append(dst, 0)
    @@ -67,9 +65,7 @@ func (src *Parse) Encode(dst []byte) []byte {
     		dst = pgio.AppendUint32(dst, oid)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/password_message.go +3 -8 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type PasswordMessage struct {
    @@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *PasswordMessage) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
    -
    +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     	dst = append(dst, src.Password...)
     	dst = append(dst, 0)
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/pgproto3.go +27 -1 modified
    @@ -4,8 +4,14 @@ import (
     	"encoding/hex"
     	"errors"
     	"fmt"
    +
    +	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
    +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
    +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
    +const maxMessageBodyLen = (0x3fffffff - 1)
    +
     // Message is the interface implemented by an object that can decode and encode
     // a particular PostgreSQL message.
     type Message interface {
    @@ -14,7 +20,7 @@ type Message interface {
     	Decode(data []byte) error
     
     	// Encode appends itself to dst and returns the new buffer.
    -	Encode(dst []byte) []byte
    +	Encode(dst []byte) ([]byte, error)
     }
     
     // FrontendMessage is a message sent by the frontend (i.e. the client).
    @@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
     	}
     	return nil, errors.New("unknown protocol representation")
     }
    +
    +// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to
    +// dst. It returns the new buffer and the position of the message length placeholder.
    +func beginMessage(dst []byte, t byte) ([]byte, int) {
    +	dst = append(dst, t)
    +	sp := len(dst)
    +	dst = pgio.AppendInt32(dst, -1)
    +	return dst, sp
    +}
    +
    +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
    +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
    +func finishMessage(dst []byte, sp int) ([]byte, error) {
    +	messageBodyLen := len(dst[sp:])
    +	if messageBodyLen > maxMessageBodyLen {
    +		return nil, errors.New("message body too large")
    +	}
    +	pgio.SetInt32(dst[sp:], int32(messageBodyLen))
    +	return dst, nil
    +}
    
  • pgproto3/pgproto3_private_test.go +3 -0 added
    @@ -0,0 +1,3 @@
    +package pgproto3
    +
    +const MaxMessageBodyLen = maxMessageBodyLen
    
  • pgproto3/portal_suspended.go +2 -2 modified
    @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *PortalSuspended) Encode(dst []byte) []byte {
    -	return append(dst, 's', 0, 0, 0, 4)
    +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 's', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/query.go +3 -8 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"bytes"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type Query struct {
    @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Query) Encode(dst []byte) []byte {
    -	dst = append(dst, 'Q')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
    -
    +func (src *Query) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'Q')
     	dst = append(dst, src.String...)
     	dst = append(dst, 0)
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/query_test.go +20 -0 added
    @@ -0,0 +1,20 @@
    +package pgproto3_test
    +
    +import (
    +	"testing"
    +
    +	"github.com/jackc/pgx/v5/pgproto3"
    +	"github.com/stretchr/testify/require"
    +)
    +
    +func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) {
    +	t.Parallel()
    +
    +	// Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string.
    +	_, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil)
    +	require.NoError(t, err)
    +
    +	// 1 byte too big
    +	_, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil)
    +	require.Error(t, err)
    +}
    
  • pgproto3/ready_for_query.go +2 -2 modified
    @@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *ReadyForQuery) Encode(dst []byte) []byte {
    -	return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
    +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/row_description.go +3 -7 modified
    @@ -99,10 +99,8 @@ func (dst *RowDescription) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *RowDescription) Encode(dst []byte) []byte {
    -	dst = append(dst, 'T')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'T')
     
     	dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
     	for _, fd := range src.Fields {
    @@ -117,9 +115,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
     		dst = pgio.AppendInt16(dst, fd.Format)
     	}
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/sasl_initial_response.go +3 -7 modified
    @@ -39,20 +39,16 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *SASLInitialResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	sp := len(dst)
    -	dst = pgio.AppendInt32(dst, -1)
    +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     
     	dst = append(dst, []byte(src.AuthMechanism)...)
     	dst = append(dst, 0)
     
     	dst = pgio.AppendInt32(dst, int32(len(src.Data)))
     	dst = append(dst, src.Data...)
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/sasl_response.go +3 -8 modified
    @@ -3,8 +3,6 @@ package pgproto3
     import (
     	"encoding/hex"
     	"encoding/json"
    -
    -	"github.com/jackc/pgx/v5/internal/pgio"
     )
     
     type SASLResponse struct {
    @@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *SASLResponse) Encode(dst []byte) []byte {
    -	dst = append(dst, 'p')
    -	dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
    -
    +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
    +	dst, sp := beginMessage(dst, 'p')
     	dst = append(dst, src.Data...)
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/ssl_request.go +2 -2 modified
    @@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 4 byte message length.
    -func (src *SSLRequest) Encode(dst []byte) []byte {
    +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
     	dst = pgio.AppendInt32(dst, 8)
     	dst = pgio.AppendInt32(dst, sslRequestNumber)
    -	return dst
    +	return dst, nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/startup_message.go +2 -4 modified
    @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *StartupMessage) Encode(dst []byte) []byte {
    +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
     	sp := len(dst)
     	dst = pgio.AppendInt32(dst, -1)
     
    @@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
     	}
     	dst = append(dst, 0)
     
    -	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
    -
    -	return dst
    +	return finishMessage(dst, sp)
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/sync.go +2 -2 modified
    @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Sync) Encode(dst []byte) []byte {
    -	return append(dst, 'S', 0, 0, 0, 4)
    +func (src *Sync) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'S', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
  • pgproto3/terminate.go +2 -2 modified
    @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
     }
     
     // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
    -func (src *Terminate) Encode(dst []byte) []byte {
    -	return append(dst, 'X', 0, 0, 0, 4)
    +func (src *Terminate) Encode(dst []byte) ([]byte, error) {
    +	return append(dst, 'X', 0, 0, 0, 4), nil
     }
     
     // MarshalJSON implements encoding/json.Marshaler.
    
f94eb0e2f967

Always wrap arguments in parentheses in the SQL sanitizer

https://github.com/jackc/pgx · Jack Christensen · Feb 24, 2024 · via ghsa
2 files changed · +14 -20
  • internal/sanitize/sanitize.go +4 -10 modified
    @@ -44,18 +44,8 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
     				str = "null"
     			case int64:
     				str = strconv.FormatInt(arg, 10)
    -				// Prevent SQL injection via Line Comment Creation
    -				// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
    -				if arg < 0 {
    -					str = "(" + str + ")"
    -				}
     			case float64:
    -				// Prevent SQL injection via Line Comment Creation
    -				// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
     				str = strconv.FormatFloat(arg, 'f', -1, 64)
    -				if arg < 0 {
    -					str = "(" + str + ")"
    -				}
     			case bool:
     				str = strconv.FormatBool(arg)
     			case []byte:
    @@ -68,6 +58,10 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) {
     				return "", fmt.Errorf("invalid arg type: %T", arg)
     			}
     			argUse[argIdx] = true
    +
    +			// Prevent SQL injection via Line Comment Creation
    +			// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
    +			str = "(" + str + ")"
     		default:
     			return "", fmt.Errorf("invalid Part type: %T", part)
     		}
    
  • internal/sanitize/sanitize_test.go +10 -10 modified
    @@ -127,52 +127,52 @@ func TestQuerySanitize(t *testing.T) {
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select 42"}},
     			args:     []interface{}{},
    -			expected: `select 42`,
    +			expected: `select (42)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{int64(42)},
    -			expected: `select 42`,
    +			expected: `select (42)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{float64(1.23)},
    -			expected: `select 1.23`,
    +			expected: `select (1.23)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{true},
    -			expected: `select true`,
    +			expected: `select (true)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{[]byte{0, 1, 2, 3, 255}},
    -			expected: `select '\x00010203ff'`,
    +			expected: `select ('\x00010203ff')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{nil},
    -			expected: `select null`,
    +			expected: `select (null)`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{"foobar"},
    -			expected: `select 'foobar'`,
    +			expected: `select ('foobar')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{"foo'bar"},
    -			expected: `select 'foo''bar'`,
    +			expected: `select ('foo''bar')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
     			args:     []interface{}{`foo\'bar`},
    -			expected: `select 'foo\''bar'`,
    +			expected: `select ('foo\''bar')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"insert ", 1}},
     			args:     []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
    -			expected: `insert '2020-03-01 23:59:59.999999Z'`,
    +			expected: `insert ('2020-03-01 23:59:59.999999Z')`,
     		},
     		{
     			query:    sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

9

News mentions

0

No linked articles in our index yet.