1
0
Fork 0

Update tracing dependencies

This commit is contained in:
Ludovic Fernandez 2019-04-05 11:58:06 +02:00 committed by Traefiker Bot
parent 4919b638f9
commit ed12366d52
98 changed files with 3371 additions and 2808 deletions

View file

@ -30,11 +30,22 @@ const (
PROTOCOL_ERROR = 7
)
var defaultApplicationExceptionMessage = map[int32]string{
UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception",
UNKNOWN_METHOD: "unknown method",
INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type",
WRONG_METHOD_NAME: "wrong method name",
BAD_SEQUENCE_ID: "bad sequence ID",
MISSING_RESULT: "missing result",
INTERNAL_ERROR: "unknown internal error",
PROTOCOL_ERROR: "unknown protocol error",
}
// Application level Thrift exception
type TApplicationException interface {
TException
TypeId() int32
Read(iprot TProtocol) (TApplicationException, error)
Read(iprot TProtocol) error
Write(oprot TProtocol) error
}
@ -44,7 +55,10 @@ type tApplicationException struct {
}
func (e tApplicationException) Error() string {
return e.message
if e.message != "" {
return e.message
}
return defaultApplicationExceptionMessage[e.type_]
}
func NewTApplicationException(type_ int32, message string) TApplicationException {
@ -55,10 +69,11 @@ func (p *tApplicationException) TypeId() int32 {
return p.type_
}
func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, error) {
func (p *tApplicationException) Read(iprot TProtocol) error {
// TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
if err != nil {
return nil, err
return err
}
message := ""
@ -67,7 +82,7 @@ func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, er
for {
_, ttype, id, err := iprot.ReadFieldBegin()
if err != nil {
return nil, err
return err
}
if ttype == STOP {
break
@ -76,33 +91,40 @@ func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, er
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
return nil, err
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return nil, err
return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
return nil, err
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return nil, err
return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return nil, err
return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
return nil, err
return err
}
}
return NewTApplicationException(type_, message), iprot.ReadStructEnd()
if err := iprot.ReadStructEnd(); err != nil {
return err
}
p.message = message
p.type_ = type_
return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
@ -447,9 +448,6 @@ func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
if size < 0 {
return nil, invalidDataLength
}
if uint64(size) > p.trans.RemainingBytes() {
return nil, invalidDataLength
}
isize := int(size)
buf := make([]byte, isize)
@ -457,8 +455,8 @@ func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
return buf, NewTProtocolException(err)
}
func (p *TBinaryProtocol) Flush() (err error) {
return NewTProtocolException(p.trans.Flush())
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
@ -480,9 +478,6 @@ func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
if uint64(size) > p.trans.RemainingBytes() {
return "", invalidDataLength
}
var (
buf bytes.Buffer

View file

@ -21,6 +21,7 @@ package thrift
import (
"bufio"
"context"
)
type TBufferedTransportFactory struct {
@ -32,8 +33,8 @@ type TBufferedTransport struct {
tp TTransport
}
func (p *TBufferedTransportFactory) GetTransport(trans TTransport) TTransport {
return NewTBufferedTransport(trans, p.size)
func (p *TBufferedTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return NewTBufferedTransport(trans, p.size), nil
}
func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory {
@ -78,12 +79,12 @@ func (p *TBufferedTransport) Write(b []byte) (int, error) {
return n, err
}
func (p *TBufferedTransport) Flush() error {
func (p *TBufferedTransport) Flush(ctx context.Context) error {
if err := p.ReadWriter.Flush(); err != nil {
p.ReadWriter.Writer.Reset(p.tp)
return err
}
return p.tp.Flush()
return p.tp.Flush(ctx)
}
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {

View file

@ -0,0 +1,85 @@
package thrift
import (
"context"
"fmt"
)
type TClient interface {
Call(ctx context.Context, method string, args, result TStruct) error
}
type TStandardClient struct {
seqId int32
iprot, oprot TProtocol
}
// TStandardClient implements TClient, and uses the standard message format for Thrift.
// It is not safe for concurrent use.
func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
return &TStandardClient{
iprot: inputProtocol,
oprot: outputProtocol,
}
}
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
return err
}
if err := args.Write(oprot); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
return err
}
return oprot.Flush(ctx)
}
func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
if err != nil {
return err
}
if method != rMethod {
return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
} else if seqId != rSeqId {
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
if err := exception.Read(iprot); err != nil {
return err
}
if err := iprot.ReadMessageEnd(); err != nil {
return err
}
return &exception
} else if rTypeId != REPLY {
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
if err := result.Read(iprot); err != nil {
return err
}
return iprot.ReadMessageEnd()
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
p.seqId++
seqId := p.seqId
if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
return err
}
// method is oneway
if result == nil {
return nil
}
return p.Recv(p.iprot, seqId, method, result)
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"encoding/binary"
"fmt"
"io"
@ -561,9 +562,6 @@ func (p *TCompactProtocol) ReadString() (value string, err error) {
if length < 0 {
return "", invalidDataLength
}
if uint64(length) > p.trans.RemainingBytes() {
return "", invalidDataLength
}
if length == 0 {
return "", nil
@ -590,17 +588,14 @@ func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
if length < 0 {
return nil, invalidDataLength
}
if uint64(length) > p.trans.RemainingBytes() {
return nil, invalidDataLength
}
buf := make([]byte, length)
_, e = io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(e)
}
func (p *TCompactProtocol) Flush() (err error) {
return NewTProtocolException(p.trans.Flush())
func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
@ -806,7 +801,7 @@ func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
case COMPACT_STRUCT:
return STRUCT, nil
}
return STOP, TException(fmt.Errorf("don't know what type: %s", t&0x0f))
return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.

View file

@ -19,12 +19,6 @@
package thrift
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(in, out TProtocol) (bool, TException)
}
import "context"
type TProcessorFunction interface {
Process(seqId int32, in, out TProtocol) (bool, TException)
}
var defaultCtx = context.Background()

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"log"
)
@ -258,8 +259,8 @@ func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
func (tdp *TDebugProtocol) Flush() (err error) {
err = tdp.Delegate.Flush()
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
return
}

View file

@ -22,6 +22,7 @@ package thrift
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
@ -48,11 +49,15 @@ func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
}
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) TTransport {
return NewTFramedTransportMaxLength(p.factory.GetTransport(base), p.maxLength)
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
return NewTFramedTransportMaxLength(tt, p.maxLength), nil
}
func NewTFramedTransport(transport TTransport) *TFramedTransport {
@ -131,21 +136,23 @@ func (p *TFramedTransport) WriteString(s string) (n int, err error) {
return p.buf.WriteString(s)
}
func (p *TFramedTransport) Flush() error {
func (p *TFramedTransport) Flush(ctx context.Context) error {
size := p.buf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
if size > 0 {
if n, err := p.buf.WriteTo(p.transport); err != nil {
print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n")
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
}
err = p.transport.Flush()
err = p.transport.Flush(ctx)
return NewTTransportExceptionFromError(err)
}
@ -164,4 +171,3 @@ func (p *TFramedTransport) readFrameHeader() (uint32, error) {
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.frameSize)
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
@ -46,27 +47,16 @@ type THttpClient struct {
type THttpClientTransportFactory struct {
options THttpClientOptions
url string
isPost bool
}
func (p *THttpClientTransportFactory) GetTransport(trans TTransport) TTransport {
func (p *THttpClientTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*THttpClient)
if ok && t.url != nil {
if t.requestBuffer != nil {
t2, _ := NewTHttpPostClientWithOptions(t.url.String(), p.options)
return t2
}
t2, _ := NewTHttpClientWithOptions(t.url.String(), p.options)
return t2
return NewTHttpClientWithOptions(t.url.String(), p.options)
}
}
if p.isPost {
s, _ := NewTHttpPostClientWithOptions(p.url, p.options)
return s
}
s, _ := NewTHttpClientWithOptions(p.url, p.options)
return s
return NewTHttpClientWithOptions(p.url, p.options)
}
type THttpClientOptions struct {
@ -79,39 +69,10 @@ func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory {
}
func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return &THttpClientTransportFactory{url: url, isPost: false, options: options}
}
func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpPostClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return &THttpClientTransportFactory{url: url, isPost: true, options: options}
return &THttpClientTransportFactory{url: url, options: options}
}
func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
}
response, err := http.Get(urlstr)
if err != nil {
return nil, err
}
client := options.Client
if client == nil {
client = DefaultHttpClient
}
httpHeader := map[string][]string{"Content-Type": []string{"application/x-thrift"}}
return &THttpClient{client: client, response: response, url: parsedURL, header: httpHeader}, nil
}
func NewTHttpClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}
func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
@ -121,12 +82,12 @@ func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (T
if client == nil {
client = DefaultHttpClient
}
httpHeader := map[string][]string{"Content-Type": []string{"application/x-thrift"}}
httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}}
return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil
}
func NewTHttpPostClient(urlstr string) (TTransport, error) {
return NewTHttpPostClientWithOptions(urlstr, THttpClientOptions{})
func NewTHttpClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}
// Set the HTTP Header for this specific Thrift Transport
@ -221,7 +182,7 @@ func (p *THttpClient) WriteString(s string) (n int, err error) {
return p.requestBuffer.WriteString(s)
}
func (p *THttpClient) Flush() error {
func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
@ -230,6 +191,9 @@ func (p *THttpClient) Flush() error {
return NewTTransportExceptionFromError(err)
}
req.Header = p.header
if ctx != nil {
req = req.WithContext(ctx)
}
response, err := p.client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)
@ -256,3 +220,23 @@ func (p *THttpClient) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
}
// Deprecated: Use NewTHttpClientTransportFactory instead.
func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
// Deprecated: Use NewTHttpClientTransportFactoryWithOptions instead.
func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, options)
}
// Deprecated: Use NewTHttpClientWithOptions instead.
func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, options)
}
// Deprecated: Use NewTHttpClient instead.
func NewTHttpPostClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}

View file

@ -19,16 +19,45 @@
package thrift
import "net/http"
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
func NewThriftHandlerFunc(processor TProcessor,
inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
return gz(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/x-thrift")
transport := NewStreamTransport(r.Body, w)
processor.Process(inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
processor.Process(r.Context(), inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
})
}
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bufio"
"context"
"io"
)
@ -38,38 +39,38 @@ type StreamTransportFactory struct {
isReadWriter bool
}
func (p *StreamTransportFactory) GetTransport(trans TTransport) TTransport {
func (p *StreamTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*StreamTransport)
if ok {
if t.isReadWriter {
return NewStreamTransportRW(t.Reader.(io.ReadWriter))
return NewStreamTransportRW(t.Reader.(io.ReadWriter)), nil
}
if t.Reader != nil && t.Writer != nil {
return NewStreamTransport(t.Reader, t.Writer)
return NewStreamTransport(t.Reader, t.Writer), nil
}
if t.Reader != nil && t.Writer == nil {
return NewStreamTransportR(t.Reader)
return NewStreamTransportR(t.Reader), nil
}
if t.Reader == nil && t.Writer != nil {
return NewStreamTransportW(t.Writer)
return NewStreamTransportW(t.Writer), nil
}
return &StreamTransport{}
return &StreamTransport{}, nil
}
}
if p.isReadWriter {
return NewStreamTransportRW(p.Reader.(io.ReadWriter))
return NewStreamTransportRW(p.Reader.(io.ReadWriter)), nil
}
if p.Reader != nil && p.Writer != nil {
return NewStreamTransport(p.Reader, p.Writer)
return NewStreamTransport(p.Reader, p.Writer), nil
}
if p.Reader != nil && p.Writer == nil {
return NewStreamTransportR(p.Reader)
return NewStreamTransportR(p.Reader), nil
}
if p.Reader == nil && p.Writer != nil {
return NewStreamTransportW(p.Writer)
return NewStreamTransportW(p.Writer), nil
}
return &StreamTransport{}
return &StreamTransport{}, nil
}
func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory {
@ -138,7 +139,7 @@ func (p *StreamTransport) Close() error {
}
// Flushes the underlying output stream if not null.
func (p *StreamTransport) Flush() error {
func (p *StreamTransport) Flush(ctx context.Context) error {
if p.Writer == nil {
return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream")
}
@ -209,6 +210,5 @@ func (p *StreamTransport) WriteString(s string) (n int, err error) {
func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the thruth is, we just don't know unless framed is used
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"encoding/base64"
"fmt"
)
@ -438,10 +439,10 @@ func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) Flush() (err error) {
func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
err = p.writer.Flush()
if err == nil {
err = p.trans.Flush()
err = p.trans.Flush(ctx)
}
return NewTProtocolException(err)
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
)
// Memory buffer-based implementation of the TTransport interface.
@ -33,14 +34,14 @@ type TMemoryBufferTransportFactory struct {
size int
}
func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) TTransport {
func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*TMemoryBuffer)
if ok && t.size > 0 {
return NewTMemoryBufferLen(t.size)
return NewTMemoryBufferLen(t.size), nil
}
}
return NewTMemoryBufferLen(p.size)
return NewTMemoryBufferLen(p.size), nil
}
func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory {
@ -70,7 +71,7 @@ func (p *TMemoryBuffer) Close() error {
}
// Flushing a memory buffer is a no-op
func (p *TMemoryBuffer) Flush() error {
func (p *TMemoryBuffer) Flush(ctx context.Context) error {
return nil
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"fmt"
"strings"
)
@ -127,7 +128,7 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProces
t.serviceProcessorMap[name] = processor
}
func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) {
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin()
if err != nil {
return false, err
@ -140,7 +141,7 @@ func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) {
if len(v) != 2 {
if t.DefaultProcessor != nil {
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(smb, out)
return t.DefaultProcessor.Process(ctx, smb, out)
}
return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
}
@ -149,7 +150,7 @@ func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) {
return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(smb, out)
return actualProcessor.Process(ctx, smb, out)
}
//Protocol that use stored message for ReadMessageBegin

View file

@ -19,6 +19,18 @@
package thrift
import "context"
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(ctx context.Context, in, out TProtocol) (bool, TException)
}
type TProcessorFunction interface {
Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}
// The default processor factory just returns a singleton
// instance.
type TProcessorFactory interface {

View file

@ -20,7 +20,9 @@
package thrift
import (
"context"
"errors"
"fmt"
)
const (
@ -73,7 +75,7 @@ type TProtocol interface {
ReadBinary() (value []byte, err error)
Skip(fieldType TType) (err error)
Flush() (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
}
@ -88,9 +90,9 @@ func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) {
// Skips over the next data element from the provided input TProtocol object.
func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
if maxDepth <= 0 {
return NewTProtocolExceptionWithType( DEPTH_LIMIT, errors.New("Depth limit exceeded"))
if maxDepth <= 0 {
return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded"))
}
switch fieldType {
@ -170,6 +172,8 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
}
}
return self.ReadListEnd()
default:
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType)))
}
return nil
}

View file

@ -60,7 +60,7 @@ func NewTProtocolException(err error) TProtocolException {
if err == nil {
return nil
}
if e,ok := err.(TProtocolException); ok {
if e, ok := err.(TProtocolException); ok {
return e
}
if _, ok := err.(base64.CorruptInputError); ok {
@ -75,4 +75,3 @@ func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
}
return &tProtocolException{errType, err.Error()}
}

View file

@ -66,4 +66,3 @@ func writeByte(w io.Writer, c byte) error {
_, err := w.Write(v[0:1])
return err
}

View file

@ -19,6 +19,10 @@
package thrift
import (
"context"
)
type TSerializer struct {
Transport *TMemoryBuffer
Protocol TProtocol
@ -38,35 +42,35 @@ func NewTSerializer() *TSerializer {
protocol}
}
func (t *TSerializer) WriteString(msg TStruct) (s string, err error) {
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(); err != nil {
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(); err != nil {
if err = t.Transport.Flush(ctx); err != nil {
return
}
return t.Transport.String(), nil
}
func (t *TSerializer) Write(msg TStruct) (b []byte, err error) {
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(); err != nil {
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(); err != nil {
if err = t.Transport.Flush(ctx); err != nil {
return
}

View file

@ -47,7 +47,14 @@ func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*T
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil
}
// Creates a TServerSocket from a net.Addr
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
}
func (p *TServerSocket) Listen() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return nil
}
@ -67,10 +74,15 @@ func (p *TServerSocket) Accept() (TTransport, error) {
if interrupted {
return nil, errTransportInterrupted
}
if p.listener == nil {
p.mu.Lock()
listener := p.listener
p.mu.Unlock()
if listener == nil {
return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
}
conn, err := p.listener.Accept()
conn, err := listener.Accept()
if err != nil {
return nil, NewTTransportExceptionFromError(err)
}
@ -84,6 +96,8 @@ func (p *TServerSocket) IsListening() bool {
// Connects the socket, creating a new socket object if necessary.
func (p *TServerSocket) Open() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
@ -103,20 +117,21 @@ func (p *TServerSocket) Addr() net.Addr {
}
func (p *TServerSocket) Close() error {
defer func() {
p.listener = nil
}()
var err error
p.mu.Lock()
if p.IsListening() {
return p.listener.Close()
err = p.listener.Close()
p.listener = nil
}
return nil
p.mu.Unlock()
return err
}
func (p *TServerSocket) Interrupt() error {
p.mu.Lock()
p.interrupted = true
p.Close()
p.mu.Unlock()
p.Close()
return nil
}

View file

@ -22,6 +22,7 @@ package thrift
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
@ -552,7 +553,7 @@ func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) {
return v, p.ParsePostValue()
}
func (p *TSimpleJSONProtocol) Flush() (err error) {
func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.writer.Flush())
}
@ -1064,7 +1065,7 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
for _, char := range line {
switch char {
default:
e := fmt.Errorf("Expecting end of list \"]\", but found: \"", line, "\"")
e := fmt.Errorf("Expecting end of list \"]\", but found: \"%v\"", line)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]):
break

View file

@ -23,11 +23,18 @@ import (
"log"
"runtime/debug"
"sync"
"sync/atomic"
)
// Simple, non-concurrent server for testing.
/*
* This is not a typical TSimpleServer as it is not blocked after accept a socket.
* It is more like a TThreadedServer that can handle different connections in different goroutines.
* This will work if golang user implements a conn-pool like thing in client side.
*/
type TSimpleServer struct {
quit chan struct{}
closed int32
wg sync.WaitGroup
mu sync.Mutex
processorFactory TProcessorFactory
serverTransport TServerTransport
@ -87,7 +94,6 @@ func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTranspor
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
quit: make(chan struct{}, 1),
}
}
@ -119,23 +125,37 @@ func (p *TSimpleServer) Listen() error {
return p.serverTransport.Listen()
}
func (p *TSimpleServer) innerAccept() (int32, error) {
client, err := p.serverTransport.Accept()
p.mu.Lock()
defer p.mu.Unlock()
closed := atomic.LoadInt32(&p.closed)
if closed != 0 {
return closed, nil
}
if err != nil {
return 0, err
}
if client != nil {
p.wg.Add(1)
go func() {
defer p.wg.Done()
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
}
}()
}
return 0, nil
}
func (p *TSimpleServer) AcceptLoop() error {
for {
client, err := p.serverTransport.Accept()
closed, err := p.innerAccept()
if err != nil {
select {
case <-p.quit:
return nil
default:
}
return err
}
if client != nil {
go func() {
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
}
}()
if closed != 0 {
return nil
}
}
}
@ -149,21 +169,28 @@ func (p *TSimpleServer) Serve() error {
return nil
}
var once sync.Once
func (p *TSimpleServer) Stop() error {
q := func() {
p.quit <- struct{}{}
p.serverTransport.Interrupt()
p.mu.Lock()
defer p.mu.Unlock()
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
once.Do(q)
atomic.StoreInt32(&p.closed, 1)
p.serverTransport.Interrupt()
p.wg.Wait()
return nil
}
func (p *TSimpleServer) processRequests(client TTransport) error {
processor := p.processorFactory.GetProcessor(client)
inputTransport := p.inputTransportFactory.GetTransport(client)
outputTransport := p.outputTransportFactory.GetTransport(client)
inputTransport, err := p.inputTransportFactory.GetTransport(client)
if err != nil {
return err
}
outputTransport, err := p.outputTransportFactory.GetTransport(client)
if err != nil {
return err
}
inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport)
defer func() {
@ -171,6 +198,7 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
log.Printf("panic in processor: %s: %s", e, debug.Stack())
}
}()
if inputTransport != nil {
defer inputTransport.Close()
}
@ -178,17 +206,20 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
defer outputTransport.Close()
}
for {
ok, err := processor.Process(inputProtocol, outputProtocol)
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
ok, err := processor.Process(defaultCtx, inputProtocol, outputProtocol)
if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
return nil
} else if err != nil {
log.Printf("error processing request: %s", err)
return err
}
if err, ok := err.(TApplicationException); ok && err.TypeId() == UNKNOWN_METHOD {
continue
}
if !ok {
if !ok {
break
}
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"net"
"time"
)
@ -148,7 +149,7 @@ func (p *TSocket) Write(buf []byte) (int, error) {
return p.conn.Write(buf)
}
func (p *TSocket) Flush() error {
func (p *TSocket) Flush(ctx context.Context) error {
return nil
}
@ -161,6 +162,5 @@ func (p *TSocket) Interrupt() error {
func (p *TSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the thruth is, we just don't know unless framed is used
}

View file

@ -20,9 +20,9 @@
package thrift
import (
"crypto/tls"
"net"
"time"
"crypto/tls"
)
type TSSLServerSocket struct {
@ -38,6 +38,9 @@ func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket,
}
func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) {
if cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
addr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"crypto/tls"
"net"
"time"
@ -48,6 +49,9 @@ func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a tls Configuration and a timeout as a time.Duration
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
if cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil
}
@ -87,7 +91,8 @@ func (p *TSSLSocket) Open() error {
// If we have a hostname, we need to pass the hostname to tls.Dial for
// certificate hostname checks.
if p.hostPort != "" {
if p.conn, err = tls.Dial("tcp", p.hostPort, p.cfg); err != nil {
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
} else {
@ -103,7 +108,8 @@ func (p *TSSLSocket) Open() error {
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
if p.conn, err = tls.Dial(p.addr.Network(), p.addr.String(), p.cfg); err != nil {
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
}
@ -153,7 +159,7 @@ func (p *TSSLSocket) Write(buf []byte) (int, error) {
return p.conn.Write(buf)
}
func (p *TSSLSocket) Flush() error {
func (p *TSSLSocket) Flush(ctx context.Context) error {
return nil
}
@ -166,6 +172,5 @@ func (p *TSSLSocket) Interrupt() error {
func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
return maxSize // the thruth is, we just don't know unless framed is used
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"errors"
"io"
)
@ -30,15 +31,18 @@ type Flusher interface {
Flush() (err error)
}
type ContextFlusher interface {
Flush(ctx context.Context) (err error)
}
type ReadSizeProvider interface {
RemainingBytes() (num_bytes uint64)
}
// Encapsulates the I/O layer
type TTransport interface {
io.ReadWriteCloser
Flusher
ContextFlusher
ReadSizeProvider
// Opens the transport for communication
@ -52,7 +56,6 @@ type stringWriter interface {
WriteString(s string) (n int, err error)
}
// This is "enchanced" transport with extra capabilities. You need to use one of these
// to construct protocol.
// Notably, TSocket does not implement this interface, and it is always a mistake to use
@ -62,7 +65,6 @@ type TRichTransport interface {
io.ByteReader
io.ByteWriter
stringWriter
Flusher
ContextFlusher
ReadSizeProvider
}

View file

@ -24,14 +24,14 @@ package thrift
// a ServerTransport and then may want to mutate them (i.e. create
// a BufferedTransport from the underlying base transport)
type TTransportFactory interface {
GetTransport(trans TTransport) TTransport
GetTransport(trans TTransport) (TTransport, error)
}
type tTransportFactory struct{}
// Return a wrapped instance of the base Transport.
func (p *tTransportFactory) GetTransport(trans TTransport) TTransport {
return trans
func (p *tTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return trans, nil
}
func NewTTransportFactory() TTransportFactory {

View file

@ -21,13 +21,15 @@ package thrift
import (
"compress/zlib"
"context"
"io"
"log"
)
// TZlibTransportFactory is a factory for TZlibTransport instances
type TZlibTransportFactory struct {
level int
level int
factory TTransportFactory
}
// TZlibTransport is a TTransport implementation that makes use of zlib compression.
@ -38,14 +40,27 @@ type TZlibTransport struct {
}
// GetTransport constructs a new instance of NewTZlibTransport
func (p *TZlibTransportFactory) GetTransport(trans TTransport) TTransport {
t, _ := NewTZlibTransport(trans, p.level)
return t
func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if p.factory != nil {
// wrap other factory
var err error
trans, err = p.factory.GetTransport(trans)
if err != nil {
return nil, err
}
}
return NewTZlibTransport(trans, p.level)
}
// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory
func NewTZlibTransportFactory(level int) *TZlibTransportFactory {
return &TZlibTransportFactory{level: level}
return &TZlibTransportFactory{level: level, factory: nil}
}
// NewTZlibTransportFactory constructs a new instance of TZlibTransportFactory
// as a wrapper over existing transport factory
func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory {
return &TZlibTransportFactory{level: level, factory: factory}
}
// NewTZlibTransport constructs a new instance of TZlibTransport
@ -77,11 +92,11 @@ func (z *TZlibTransport) Close() error {
}
// Flush flushes the writer and its underlying transport.
func (z *TZlibTransport) Flush() error {
func (z *TZlibTransport) Flush(ctx context.Context) error {
if err := z.writer.Flush(); err != nil {
return err
}
return z.transport.Flush()
return z.transport.Flush(ctx)
}
// IsOpen returns true if the transport is open