Bumping k8s dependencies to 1.13

This commit is contained in:
Cheng Xing
2018-11-16 14:08:25 -08:00
parent 305407125c
commit b4c0b68ec7
8002 changed files with 884099 additions and 276228 deletions

2
vendor/google.golang.org/grpc/.github/lock.yml generated vendored Normal file
View File

@@ -0,0 +1,2 @@
daysUntilLock: 180
lockComment: false

View File

@@ -1,23 +1,39 @@
language: go
go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
matrix:
include:
- go: 1.9.x
- go: 1.11.x
env: VET=1 GO111MODULE=on
- go: 1.11.x
env: RACE=1 GO111MODULE=on
- go: 1.11.x
env: RUN386=1
- go: 1.11.x
env: GRPC_GO_RETRY=on
- go: 1.10.x
- go: 1.9.x
- go: 1.9.x
env: GAE=1
- go: 1.8.x
- go: 1.6.x
go_import_path: google.golang.org/grpc
before_install:
- if [[ "$TRAVIS_GO_VERSION" = 1.9* && "$GOARCH" != "386" ]]; then ./vet.sh -install || exit 1; fi
- if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi
- if [[ -n "${RUN386}" ]]; then export GOARCH=386; fi
- if [[ "${TRAVIS_EVENT_TYPE}" = "cron" && -z "${RUN386}" ]]; then RACE=1; fi
- if [[ "${TRAVIS_EVENT_TYPE}" != "cron" ]]; then VET_SKIP_PROTO=1; fi
install:
- try3() { eval "$*" || eval "$*" || eval "$*"; }
- try3 'if [[ "${GO111MODULE}" = "on" ]]; then go mod download; else make testdeps; fi'
- if [[ "${GAE}" = 1 ]]; then source ./install_gae.sh; make testappenginedeps; fi
- if [[ "${VET}" = 1 ]]; then ./vet.sh -install; fi
script:
- if [[ -n "$RUN386" ]]; then export GOARCH=386; fi
- if [[ "$TRAVIS_GO_VERSION" = 1.9* && "$GOARCH" != "386" ]]; then ./vet.sh || exit 1; fi
- make test || exit 1
- if [[ "$GOARCH" != "386" ]]; then make testrace; fi
- set -e
- if [[ "${VET}" = 1 ]]; then ./vet.sh; fi
- if [[ "${GAE}" = 1 ]]; then make testappengine; exit 0; fi
- if [[ "${RACE}" = 1 ]]; then make testrace; exit 0; fi
- make test

View File

@@ -27,6 +27,10 @@ How to get your contributions merged smoothly and quickly.
- Keep your PR up to date with upstream/master (if there are merge conflicts, we can't really merge your change).
- **All tests need to be passing** before your change can be merged. We recommend you **run tests locally** before creating your PR to catch breakages early on.
- `make all` to test everything, OR
- `make vet` to catch vet errors
- `make test` to run the tests
- `make testrace` to run tests in race mode
- Exceptions to the rules can be made if there's a compelling reason for doing so.

View File

@@ -0,0 +1,33 @@
# Concurrency
In general, gRPC-go provides a concurrency-friendly API. What follows are some
guidelines.
## Clients
A [ClientConn][client-conn] can safely be accessed concurrently. Using
[helloworld][helloworld] as an example, one could share the `ClientConn` across
multiple goroutines to create multiple `GreeterClient` types. In this case, RPCs
would be sent in parallel.
## Streams
When using streams, one must take care to avoid calling either `SendMsg` or
`RecvMsg` multiple times against the same [Stream][stream] from different
goroutines. In other words, it's safe to have a goroutine calling `SendMsg` and
another goroutine calling `RecvMsg` on the same stream at the same time. But it
is not safe to call `SendMsg` on the same stream in different goroutines, or to
call `RecvMsg` on the same stream in different goroutines.
## Servers
Each RPC handler attached to a registered server will be invoked in its own
goroutine. For example, [SayHello][say-hello] will be invoked in its own
goroutine. The same is true for service handlers for streaming RPCs, as seen
in the route guide example [here][route-guide-stream].
[helloworld]: https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_client/main.go#L43
[client-conn]: https://godoc.org/google.golang.org/grpc#ClientConn
[stream]: https://godoc.org/google.golang.org/grpc#Stream
[say-hello]: https://github.com/grpc/grpc-go/blob/master/examples/helloworld/greeter_server/main.go#L41
[route-guide-stream]: https://github.com/grpc/grpc-go/blob/master/examples/route_guide/server/server.go#L126

View File

@@ -21,6 +21,43 @@ server := grpc.NewServer(grpc.Creds(creds))
server.Serve(lis)
```
# OAuth2
For an example of how to configure client and server to use OAuth2 tokens, see
[here](https://github.com/grpc/grpc-go/blob/master/examples/oauth/).
## Validating a token on the server
Clients may use
[metadata.MD](https://godoc.org/google.golang.org/grpc/metadata#MD)
to store tokens and other authentication-related data. To gain access to the
`metadata.MD` object, a server may use
[metadata.FromIncomingContext](https://godoc.org/google.golang.org/grpc/metadata#FromIncomingContext).
With a reference to `metadata.MD` on the server, one needs to simply lookup the
`authorization` key. Note, all keys stored within `metadata.MD` are normalized
to lowercase. See [here](https://godoc.org/google.golang.org/grpc/metadata#New).
It is possible to configure token validation for all RPCs using an interceptor.
A server may configure either a
[grpc.UnaryInterceptor](https://godoc.org/google.golang.org/grpc#UnaryInterceptor)
or a
[grpc.StreamInterceptor](https://godoc.org/google.golang.org/grpc#StreamInterceptor).
## Adding a token to all outgoing client RPCs
To send an OAuth2 token with each RPC, a client may configure the
`grpc.DialOption`
[grpc.WithPerRPCCredentials](https://godoc.org/google.golang.org/grpc#WithPerRPCCredentials).
Alternatively, a client may also use the `grpc.CallOption`
[grpc.PerRPCCredentials](https://godoc.org/google.golang.org/grpc#PerRPCCredentials)
on each invocation of an RPC.
To create a `credentials.PerRPCCredentials`, use
[oauth.NewOauthAccess](https://godoc.org/google.golang.org/grpc/credentials/oauth#NewOauthAccess).
Note, the OAuth2 implementation of `grpc.PerRPCCredentials` requires a client to use
[grpc.WithTransportCredentials](https://godoc.org/google.golang.org/grpc#WithTransportCredentials)
to prevent any insecure transmission of tokens.
# Authenticating with Google
## Google Compute Engine (GCE)

View File

@@ -0,0 +1,46 @@
# Keepalive
gRPC sends http2 pings on the transport to detect if the connection is down. If
the ping is not acknowledged by the other side within a certain period, the
connection will be close. Note that pings are only necessary when there's no
activity on the connection.
For how to configure keepalive, see
https://godoc.org/google.golang.org/grpc/keepalive for the options.
## What should I set?
It should be sufficient for most users to set [client
parameters](https://godoc.org/google.golang.org/grpc/keepalive) as a [dial
option](https://godoc.org/google.golang.org/grpc#WithKeepaliveParams).
## What will happen?
(The behavior described here is specific for gRPC-go, it might be slightly
different in other languages.)
When there's no activity on a connection (note that an ongoing stream results in
__no activity__ when there's no message being sent), after `Time`, a ping will
be sent by the client and the server will send a ping ack when it gets the ping.
Client will wait for `Timeout`, and check if there's any activity on the
connection during this period (a ping ack is an activity).
## What about server side?
Server has similar `Time` and `Timeout` settings as client. Server can also
configure connection max-age. See [server
parameters](https://godoc.org/google.golang.org/grpc/keepalive#ServerParameters)
for details.
### Enforcement policy
[Enforcement
policy](https://godoc.org/google.golang.org/grpc/keepalive#ServerParameters) is
a special setting on server side to protect server from malicious or misbehaving
clients.
Server sends GOAWAY with ENHANCE_YOUR_CALM and close the connection when bad
behaviors are detected:
- Client sends too frequent pings
- Client sends pings when there's no stream and this is disallowed by server
config

View File

@@ -0,0 +1,49 @@
# Log Levels
This document describes the different log levels supported by the grpc-go
library, and under what conditions they should be used.
### Info
Info messages are for informational purposes and may aid in the debugging of
applications or the gRPC library.
Examples:
- The name resolver received an update.
- The balancer updated its picker.
- Significant gRPC state is changing.
At verbosity of 0 (the default), any single info message should not be output
more than once every 5 minutes under normal operation.
### Warning
Warning messages indicate problems that are non-fatal for the application, but
could lead to unexpected behavior or subsequent errors.
Examples:
- Resolver could not resolve target name.
- Error received while connecting to a server.
- Lost or corrupt connection with remote endpoint.
### Error
Error messages represent errors in the usage of gRPC that cannot be returned to
the application as errors, or internal gRPC-Go errors that are recoverable.
Internal errors are detected during gRPC tests and will result in test failures.
Examples:
- Invalid arguments passed to a function that cannot return an error.
- An internal error that cannot be returned or would be inappropriate to return
to the user.
### Fatal
Fatal errors are severe internal errors that are unrecoverable. These lead
directly to panics, and are avoided as much as possible.
Example:
- Internal invariant was violated.
- User attempted an action that cannot return an error gracefully, but would
lead to an invalid state if performed.

15
vendor/google.golang.org/grpc/Documentation/proxy.md generated vendored Normal file
View File

@@ -0,0 +1,15 @@
# Proxy
HTTP CONNECT proxies are supported by default in gRPC. The proxy address can be
specified by the environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or
the lowercase versions thereof).
## Custom proxy
Currently, proxy support is implemented in the default dialer. It does one more
handshake (a CONNECT handshake in the case of HTTP CONNECT proxy) on the
connection before giving it to gRPC.
If the default proxy doesn't work for you, replace the default dialer with your
custom proxy dialer. This can be done using
[`WithDialer`](https://godoc.org/google.golang.org/grpc#WithDialer).

View File

@@ -0,0 +1,68 @@
# RPC Errors
All service method handlers should return `nil` or errors from the
`status.Status` type. Clients have direct access to the errors.
Upon encountering an error, a gRPC server method handler should create a
`status.Status`. In typical usage, one would use [status.New][new-status]
passing in an appropriate [codes.Code][code] as well as a description of the
error to produce a `status.Status`. Calling [status.Err][status-err] converts
the `status.Status` type into an `error`. As a convenience method, there is also
[status.Error][status-error] which obviates the conversion step. Compare:
```
st := status.New(codes.NotFound, "some description")
err := st.Err()
// vs.
err := status.Error(codes.NotFound, "some description")
```
## Adding additional details to errors
In some cases, it may be necessary to add details for a particular error on the
server side. The [status.WithDetails][with-details] method exists for this
purpose. Clients may then read those details by first converting the plain
`error` type back to a [status.Status][status] and then using
[status.Details][details].
## Example
The [example][example] demonstrates the API discussed above and shows how to add
information about rate limits to the error message using `status.Status`.
To run the example, first start the server:
```
$ go run examples/rpc_errors/server/main.go
```
In a separate session, run the client:
```
$ go run examples/rpc_errors/client/main.go
```
On the first run of the client, all is well:
```
2018/03/12 19:39:33 Greeting: Hello world
```
Upon running the client a second time, the client exceeds the rate limit and
receives an error with details:
```
2018/03/19 16:42:01 Quota failure: violations:<subject:"name:world" description:"Limit one greeting per person" >
exit status 1
```
[status]: https://godoc.org/google.golang.org/grpc/status#Status
[new-status]: https://godoc.org/google.golang.org/grpc/status#New
[code]: https://godoc.org/google.golang.org/grpc/codes#Code
[with-details]: https://godoc.org/google.golang.org/grpc/status#Status.WithDetails
[details]: https://godoc.org/google.golang.org/grpc/status#Status.Details
[status-err]: https://godoc.org/google.golang.org/grpc/status#Status.Err
[status-error]: https://godoc.org/google.golang.org/grpc/status#Error
[example]: https://github.com/grpc/grpc-go/blob/master/examples/rpc_errors

View File

@@ -1,20 +1,14 @@
all: test testrace
deps:
go get -d -v google.golang.org/grpc/...
updatedeps:
go get -d -v -u -f google.golang.org/grpc/...
testdeps:
go get -d -v -t google.golang.org/grpc/...
updatetestdeps:
go get -d -v -t -u -f google.golang.org/grpc/...
all: vet test testrace testappengine
build: deps
go build google.golang.org/grpc/...
clean:
go clean -i google.golang.org/grpc/...
deps:
go get -d -v google.golang.org/grpc/...
proto:
@ if ! which protoc > /dev/null; then \
echo "error: protoc not installed" >&2; \
@@ -23,23 +17,44 @@ proto:
go generate google.golang.org/grpc/...
test: testdeps
go test -cpu 1,4 -timeout 5m google.golang.org/grpc/...
go test -cpu 1,4 -timeout 7m google.golang.org/grpc/...
testappengine: testappenginedeps
goapp test -cpu 1,4 -timeout 7m google.golang.org/grpc/...
testappenginedeps:
goapp get -d -v -t -tags 'appengine appenginevm' google.golang.org/grpc/...
testdeps:
go get -d -v -t google.golang.org/grpc/...
testrace: testdeps
go test -race -cpu 1,4 -timeout 7m google.golang.org/grpc/...
clean:
go clean -i google.golang.org/grpc/...
updatedeps:
go get -d -v -u -f google.golang.org/grpc/...
updatetestdeps:
go get -d -v -t -u -f google.golang.org/grpc/...
vet: vetdeps
./vet.sh
vetdeps:
./vet.sh -install
.PHONY: \
all \
deps \
updatedeps \
testdeps \
updatetestdeps \
build \
clean \
deps \
proto \
test \
testappengine \
testappenginedeps \
testdeps \
testrace \
clean \
coverage
updatedeps \
updatetestdeps \
vet \
vetdeps

View File

@@ -16,8 +16,7 @@ $ go get -u google.golang.org/grpc
Prerequisites
-------------
This requires Go 1.6 or later. Go 1.7 will be required as of the next gRPC-Go
release (1.8).
This requires Go 1.6 or later. Go 1.7 will be required soon.
Constraints
-----------
@@ -44,3 +43,25 @@ Please update proto package, gRPC package and rebuild the proto files:
- `go get -u github.com/golang/protobuf/{proto,protoc-gen-go}`
- `go get -u google.golang.org/grpc`
- `protoc --go_out=plugins=grpc:. *.proto`
#### How to turn on logging
The default logger is controlled by the environment variables. Turn everything
on by setting:
```
GRPC_GO_LOG_VERBOSITY_LEVEL=99 GRPC_GO_LOG_SEVERITY_LEVEL=info
```
#### The RPC failed with error `"code = Unavailable desc = transport is closing"`
This error means the connection the RPC is using was closed, and there are many
possible reasons, including:
1. mis-configured transport credentials, connection failed on handshaking
1. bytes disrupted, possibly by a proxy in between
1. server shutdown
It can be tricky to debug this because the error happens on the client side but
the root cause of the connection being closed is on the server side. Turn on
logging on __both client and server__, and see if there are any transport
errors.

View File

@@ -16,81 +16,23 @@
*
*/
// See internal/backoff package for the backoff implementation. This file is
// kept for the exported types and API backward compatility.
package grpc
import (
"math/rand"
"time"
)
// DefaultBackoffConfig uses values specified for backoff in
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
var DefaultBackoffConfig = BackoffConfig{
MaxDelay: 120 * time.Second,
baseDelay: 1.0 * time.Second,
factor: 1.6,
jitter: 0.2,
}
// backoffStrategy defines the methodology for backing off after a grpc
// connection failure.
//
// This is unexported until the gRPC project decides whether or not to allow
// alternative backoff strategies. Once a decision is made, this type and its
// method may be exported.
type backoffStrategy interface {
// backoff returns the amount of time to wait before the next retry given
// the number of consecutive failures.
backoff(retries int) time.Duration
MaxDelay: 120 * time.Second,
}
// BackoffConfig defines the parameters for the default gRPC backoff strategy.
type BackoffConfig struct {
// MaxDelay is the upper bound of backoff delay.
MaxDelay time.Duration
// TODO(stevvooe): The following fields are not exported, as allowing
// changes would violate the current gRPC specification for backoff. If
// gRPC decides to allow more interesting backoff strategies, these fields
// may be opened up in the future.
// baseDelay is the amount of time to wait before retrying after the first
// failure.
baseDelay time.Duration
// factor is applied to the backoff after each retry.
factor float64
// jitter provides a range to randomize backoff delays.
jitter float64
}
func setDefaults(bc *BackoffConfig) {
md := bc.MaxDelay
*bc = DefaultBackoffConfig
if md > 0 {
bc.MaxDelay = md
}
}
func (bc BackoffConfig) backoff(retries int) time.Duration {
if retries == 0 {
return bc.baseDelay
}
backoff, max := float64(bc.baseDelay), float64(bc.MaxDelay)
for backoff < max && retries > 0 {
backoff *= bc.factor
retries--
}
if backoff > max {
backoff = max
}
// Randomize backoff delays so that if a cluster of requests start at
// the same time, they won't operate in lockstep.
backoff *= 1 + bc.jitter*(rand.Float64()*2-1)
if backoff < 0 {
return 0
}
return time.Duration(backoff)
}

View File

@@ -19,7 +19,6 @@
package grpc
import (
"fmt"
"net"
"sync"
@@ -32,7 +31,8 @@ import (
)
// Address represents a server the client connects to.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
//
// Deprecated: please use package balancer.
type Address struct {
// Addr is the server address on which a connection will be established.
Addr string
@@ -42,6 +42,8 @@ type Address struct {
}
// BalancerConfig specifies the configurations for Balancer.
//
// Deprecated: please use package balancer.
type BalancerConfig struct {
// DialCreds is the transport credential the Balancer implementation can
// use to dial to a remote load balancer server. The Balancer implementations
@@ -54,7 +56,8 @@ type BalancerConfig struct {
}
// BalancerGetOptions configures a Get call.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
//
// Deprecated: please use package balancer.
type BalancerGetOptions struct {
// BlockingWait specifies whether Get should block when there is no
// connected address.
@@ -62,7 +65,8 @@ type BalancerGetOptions struct {
}
// Balancer chooses network addresses for RPCs.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
//
// Deprecated: please use package balancer.
type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the name resolution and watch the updates. It will
@@ -113,28 +117,10 @@ type Balancer interface {
Close() error
}
// downErr implements net.Error. It is constructed by gRPC internals and passed to the down
// call of Balancer.
type downErr struct {
timeout bool
temporary bool
desc string
}
func (e downErr) Error() string { return e.desc }
func (e downErr) Timeout() bool { return e.timeout }
func (e downErr) Temporary() bool { return e.temporary }
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
return downErr{
timeout: timeout,
temporary: temporary,
desc: fmt.Sprintf(format, a...),
}
}
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
// the name resolution updates and updates the addresses available correspondingly.
//
// Deprecated: please use package balancer/roundrobin.
func RoundRobin(r naming.Resolver) Balancer {
return &roundRobin{r: r}
}
@@ -403,7 +389,3 @@ func (rr *roundRobin) Close() error {
type pickFirst struct {
*roundRobin
}
func pickFirstBalancerV1(r naming.Resolver) Balancer {
return &pickFirst{&roundRobin{r: r}}
}

View File

@@ -28,6 +28,7 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
@@ -36,9 +37,12 @@ var (
m = make(map[string]Builder)
)
// Register registers the balancer builder to the balancer map.
// b.Name (lowercased) will be used as the name registered with
// this builder.
// Register registers the balancer builder to the balancer map. b.Name
// (lowercased) will be used as the name registered with this builder.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Balancers are
// registered with the same name, the one registered last will take effect.
func Register(b Builder) {
m[strings.ToLower(b.Name())] = b
}
@@ -85,7 +89,12 @@ type SubConn interface {
}
// NewSubConnOptions contains options to create new SubConn.
type NewSubConnOptions struct{}
type NewSubConnOptions struct {
// CredsBundle is the credentials bundle that will be used in the created
// SubConn. If it's nil, the original creds from grpc DialOptions will be
// used.
CredsBundle credentials.Bundle
}
// ClientConn represents a gRPC ClientConn.
//
@@ -122,10 +131,14 @@ type BuildOptions struct {
// use to dial to a remote load balancer server. The Balancer implementations
// can ignore this if it does not need to talk to another party securely.
DialCreds credentials.TransportCredentials
// CredsBundle is the credentials bundle that the Balancer can use.
CredsBundle credentials.Bundle
// Dialer is the custom dialer the Balancer implementation can use to dial
// to a remote load balancer server. The Balancer implementations
// can ignore this if it doesn't need to talk to remote balancer.
Dialer func(context.Context, string) (net.Conn, error)
// ChannelzParentID is the entity parent's channelz unique identification number.
ChannelzParentID int64
}
// Builder creates a balancer.
@@ -138,12 +151,21 @@ type Builder interface {
}
// PickOptions contains addition information for the Pick operation.
type PickOptions struct{}
type PickOptions struct {
// FullMethodName is the method name that NewClientStream() is called
// with. The canonical format is /service/Method.
FullMethodName string
// Header contains the metadata from the RPC's client header. The metadata
// should not be modified; make a copy first if needed.
Header metadata.MD
}
// DoneInfo contains additional information for done.
type DoneInfo struct {
// Err is the rpc error the RPC finished with. It could be nil.
Err error
// Trailer contains the metadata from the RPC's trailer, if present.
Trailer metadata.MD
// BytesSent indicates if any bytes have been sent to the server.
BytesSent bool
// BytesReceived indicates if any byte has been received from the server.
@@ -160,7 +182,7 @@ var (
)
// Picker is used by gRPC to pick a SubConn to send an RPC.
// Balancer is expected to generate a new picker from its snapshot everytime its
// Balancer is expected to generate a new picker from its snapshot every time its
// internal state has changed.
//
// The pickers used by gRPC can be updated by ClientConn.UpdateBalancerState().
@@ -221,3 +243,45 @@ type Balancer interface {
// ClientConn.RemoveSubConn for its existing SubConns.
Close()
}
// ConnectivityStateEvaluator takes the connectivity states of multiple SubConns
// and returns one aggregated connectivity state.
//
// It's not thread safe.
type ConnectivityStateEvaluator struct {
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// RecordTransition records state change happening in subConn and based on that
// it evaluates what aggregated state should be.
//
// - If at least one SubConn in Ready, the aggregated state is Ready;
// - Else if at least one SubConn in Connecting, the aggregated state is Connecting;
// - Else the aggregated state is TransientFailure.
//
// Idle and Shutdown are not considered.
func (cse *ConnectivityStateEvaluator) RecordTransition(oldState, newState connectivity.State) connectivity.State {
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}

View File

@@ -146,7 +146,6 @@ func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectiv
}
b.cc.UpdateBalancerState(b.state, b.picker)
return
}
// Close is a nop because base balancer doesn't have internal state to clean up,

View File

@@ -0,0 +1,839 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/lb/v1/load_balancer.proto
package grpc_lb_v1 // import "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import duration "github.com/golang/protobuf/ptypes/duration"
import timestamp "github.com/golang/protobuf/ptypes/timestamp"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type LoadBalanceRequest struct {
// Types that are valid to be assigned to LoadBalanceRequestType:
// *LoadBalanceRequest_InitialRequest
// *LoadBalanceRequest_ClientStats
LoadBalanceRequestType isLoadBalanceRequest_LoadBalanceRequestType `protobuf_oneof:"load_balance_request_type"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoadBalanceRequest) Reset() { *m = LoadBalanceRequest{} }
func (m *LoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceRequest) ProtoMessage() {}
func (*LoadBalanceRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{0}
}
func (m *LoadBalanceRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_LoadBalanceRequest.Unmarshal(m, b)
}
func (m *LoadBalanceRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_LoadBalanceRequest.Marshal(b, m, deterministic)
}
func (dst *LoadBalanceRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_LoadBalanceRequest.Merge(dst, src)
}
func (m *LoadBalanceRequest) XXX_Size() int {
return xxx_messageInfo_LoadBalanceRequest.Size(m)
}
func (m *LoadBalanceRequest) XXX_DiscardUnknown() {
xxx_messageInfo_LoadBalanceRequest.DiscardUnknown(m)
}
var xxx_messageInfo_LoadBalanceRequest proto.InternalMessageInfo
type isLoadBalanceRequest_LoadBalanceRequestType interface {
isLoadBalanceRequest_LoadBalanceRequestType()
}
type LoadBalanceRequest_InitialRequest struct {
InitialRequest *InitialLoadBalanceRequest `protobuf:"bytes,1,opt,name=initial_request,json=initialRequest,proto3,oneof"`
}
type LoadBalanceRequest_ClientStats struct {
ClientStats *ClientStats `protobuf:"bytes,2,opt,name=client_stats,json=clientStats,proto3,oneof"`
}
func (*LoadBalanceRequest_InitialRequest) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (*LoadBalanceRequest_ClientStats) isLoadBalanceRequest_LoadBalanceRequestType() {}
func (m *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType {
if m != nil {
return m.LoadBalanceRequestType
}
return nil
}
func (m *LoadBalanceRequest) GetInitialRequest() *InitialLoadBalanceRequest {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_InitialRequest); ok {
return x.InitialRequest
}
return nil
}
func (m *LoadBalanceRequest) GetClientStats() *ClientStats {
if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_ClientStats); ok {
return x.ClientStats
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceRequest_OneofMarshaler, _LoadBalanceRequest_OneofUnmarshaler, _LoadBalanceRequest_OneofSizer, []interface{}{
(*LoadBalanceRequest_InitialRequest)(nil),
(*LoadBalanceRequest_ClientStats)(nil),
}
}
func _LoadBalanceRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialRequest); err != nil {
return err
}
case *LoadBalanceRequest_ClientStats:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ClientStats); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceRequest.LoadBalanceRequestType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceRequest)
switch tag {
case 1: // load_balance_request_type.initial_request
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceRequest)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_InitialRequest{msg}
return true, err
case 2: // load_balance_request_type.client_stats
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ClientStats)
err := b.DecodeMessage(msg)
m.LoadBalanceRequestType = &LoadBalanceRequest_ClientStats{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceRequest_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceRequest)
// load_balance_request_type
switch x := m.LoadBalanceRequestType.(type) {
case *LoadBalanceRequest_InitialRequest:
s := proto.Size(x.InitialRequest)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceRequest_ClientStats:
s := proto.Size(x.ClientStats)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceRequest struct {
// The name of the load balanced service (e.g., service.googleapis.com). Its
// length should be less than 256 bytes.
// The name might include a port number. How to handle the port number is up
// to the balancer.
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *InitialLoadBalanceRequest) Reset() { *m = InitialLoadBalanceRequest{} }
func (m *InitialLoadBalanceRequest) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceRequest) ProtoMessage() {}
func (*InitialLoadBalanceRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{1}
}
func (m *InitialLoadBalanceRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_InitialLoadBalanceRequest.Unmarshal(m, b)
}
func (m *InitialLoadBalanceRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_InitialLoadBalanceRequest.Marshal(b, m, deterministic)
}
func (dst *InitialLoadBalanceRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_InitialLoadBalanceRequest.Merge(dst, src)
}
func (m *InitialLoadBalanceRequest) XXX_Size() int {
return xxx_messageInfo_InitialLoadBalanceRequest.Size(m)
}
func (m *InitialLoadBalanceRequest) XXX_DiscardUnknown() {
xxx_messageInfo_InitialLoadBalanceRequest.DiscardUnknown(m)
}
var xxx_messageInfo_InitialLoadBalanceRequest proto.InternalMessageInfo
func (m *InitialLoadBalanceRequest) GetName() string {
if m != nil {
return m.Name
}
return ""
}
// Contains the number of calls finished for a particular load balance token.
type ClientStatsPerToken struct {
// See Server.load_balance_token.
LoadBalanceToken string `protobuf:"bytes,1,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// The total number of RPCs that finished associated with the token.
NumCalls int64 `protobuf:"varint,2,opt,name=num_calls,json=numCalls,proto3" json:"num_calls,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStatsPerToken) Reset() { *m = ClientStatsPerToken{} }
func (m *ClientStatsPerToken) String() string { return proto.CompactTextString(m) }
func (*ClientStatsPerToken) ProtoMessage() {}
func (*ClientStatsPerToken) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{2}
}
func (m *ClientStatsPerToken) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStatsPerToken.Unmarshal(m, b)
}
func (m *ClientStatsPerToken) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStatsPerToken.Marshal(b, m, deterministic)
}
func (dst *ClientStatsPerToken) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStatsPerToken.Merge(dst, src)
}
func (m *ClientStatsPerToken) XXX_Size() int {
return xxx_messageInfo_ClientStatsPerToken.Size(m)
}
func (m *ClientStatsPerToken) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStatsPerToken.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStatsPerToken proto.InternalMessageInfo
func (m *ClientStatsPerToken) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *ClientStatsPerToken) GetNumCalls() int64 {
if m != nil {
return m.NumCalls
}
return 0
}
// Contains client level statistics that are useful to load balancing. Each
// count except the timestamp should be reset to zero after reporting the stats.
type ClientStats struct {
// The timestamp of generating the report.
Timestamp *timestamp.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// The total number of RPCs that started.
NumCallsStarted int64 `protobuf:"varint,2,opt,name=num_calls_started,json=numCallsStarted,proto3" json:"num_calls_started,omitempty"`
// The total number of RPCs that finished.
NumCallsFinished int64 `protobuf:"varint,3,opt,name=num_calls_finished,json=numCallsFinished,proto3" json:"num_calls_finished,omitempty"`
// The total number of RPCs that failed to reach a server except dropped RPCs.
NumCallsFinishedWithClientFailedToSend int64 `protobuf:"varint,6,opt,name=num_calls_finished_with_client_failed_to_send,json=numCallsFinishedWithClientFailedToSend,proto3" json:"num_calls_finished_with_client_failed_to_send,omitempty"`
// The total number of RPCs that finished and are known to have been received
// by a server.
NumCallsFinishedKnownReceived int64 `protobuf:"varint,7,opt,name=num_calls_finished_known_received,json=numCallsFinishedKnownReceived,proto3" json:"num_calls_finished_known_received,omitempty"`
// The list of dropped calls.
CallsFinishedWithDrop []*ClientStatsPerToken `protobuf:"bytes,8,rep,name=calls_finished_with_drop,json=callsFinishedWithDrop,proto3" json:"calls_finished_with_drop,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{3}
}
func (m *ClientStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStats.Unmarshal(m, b)
}
func (m *ClientStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStats.Marshal(b, m, deterministic)
}
func (dst *ClientStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStats.Merge(dst, src)
}
func (m *ClientStats) XXX_Size() int {
return xxx_messageInfo_ClientStats.Size(m)
}
func (m *ClientStats) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStats.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStats proto.InternalMessageInfo
func (m *ClientStats) GetTimestamp() *timestamp.Timestamp {
if m != nil {
return m.Timestamp
}
return nil
}
func (m *ClientStats) GetNumCallsStarted() int64 {
if m != nil {
return m.NumCallsStarted
}
return 0
}
func (m *ClientStats) GetNumCallsFinished() int64 {
if m != nil {
return m.NumCallsFinished
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedWithClientFailedToSend() int64 {
if m != nil {
return m.NumCallsFinishedWithClientFailedToSend
}
return 0
}
func (m *ClientStats) GetNumCallsFinishedKnownReceived() int64 {
if m != nil {
return m.NumCallsFinishedKnownReceived
}
return 0
}
func (m *ClientStats) GetCallsFinishedWithDrop() []*ClientStatsPerToken {
if m != nil {
return m.CallsFinishedWithDrop
}
return nil
}
type LoadBalanceResponse struct {
// Types that are valid to be assigned to LoadBalanceResponseType:
// *LoadBalanceResponse_InitialResponse
// *LoadBalanceResponse_ServerList
LoadBalanceResponseType isLoadBalanceResponse_LoadBalanceResponseType `protobuf_oneof:"load_balance_response_type"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *LoadBalanceResponse) Reset() { *m = LoadBalanceResponse{} }
func (m *LoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*LoadBalanceResponse) ProtoMessage() {}
func (*LoadBalanceResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{4}
}
func (m *LoadBalanceResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_LoadBalanceResponse.Unmarshal(m, b)
}
func (m *LoadBalanceResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_LoadBalanceResponse.Marshal(b, m, deterministic)
}
func (dst *LoadBalanceResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_LoadBalanceResponse.Merge(dst, src)
}
func (m *LoadBalanceResponse) XXX_Size() int {
return xxx_messageInfo_LoadBalanceResponse.Size(m)
}
func (m *LoadBalanceResponse) XXX_DiscardUnknown() {
xxx_messageInfo_LoadBalanceResponse.DiscardUnknown(m)
}
var xxx_messageInfo_LoadBalanceResponse proto.InternalMessageInfo
type isLoadBalanceResponse_LoadBalanceResponseType interface {
isLoadBalanceResponse_LoadBalanceResponseType()
}
type LoadBalanceResponse_InitialResponse struct {
InitialResponse *InitialLoadBalanceResponse `protobuf:"bytes,1,opt,name=initial_response,json=initialResponse,proto3,oneof"`
}
type LoadBalanceResponse_ServerList struct {
ServerList *ServerList `protobuf:"bytes,2,opt,name=server_list,json=serverList,proto3,oneof"`
}
func (*LoadBalanceResponse_InitialResponse) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (*LoadBalanceResponse_ServerList) isLoadBalanceResponse_LoadBalanceResponseType() {}
func (m *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType {
if m != nil {
return m.LoadBalanceResponseType
}
return nil
}
func (m *LoadBalanceResponse) GetInitialResponse() *InitialLoadBalanceResponse {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_InitialResponse); ok {
return x.InitialResponse
}
return nil
}
func (m *LoadBalanceResponse) GetServerList() *ServerList {
if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_ServerList); ok {
return x.ServerList
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*LoadBalanceResponse) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _LoadBalanceResponse_OneofMarshaler, _LoadBalanceResponse_OneofUnmarshaler, _LoadBalanceResponse_OneofSizer, []interface{}{
(*LoadBalanceResponse_InitialResponse)(nil),
(*LoadBalanceResponse_ServerList)(nil),
}
}
func _LoadBalanceResponse_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
b.EncodeVarint(1<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.InitialResponse); err != nil {
return err
}
case *LoadBalanceResponse_ServerList:
b.EncodeVarint(2<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ServerList); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("LoadBalanceResponse.LoadBalanceResponseType has unexpected type %T", x)
}
return nil
}
func _LoadBalanceResponse_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*LoadBalanceResponse)
switch tag {
case 1: // load_balance_response_type.initial_response
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(InitialLoadBalanceResponse)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_InitialResponse{msg}
return true, err
case 2: // load_balance_response_type.server_list
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ServerList)
err := b.DecodeMessage(msg)
m.LoadBalanceResponseType = &LoadBalanceResponse_ServerList{msg}
return true, err
default:
return false, nil
}
}
func _LoadBalanceResponse_OneofSizer(msg proto.Message) (n int) {
m := msg.(*LoadBalanceResponse)
// load_balance_response_type
switch x := m.LoadBalanceResponseType.(type) {
case *LoadBalanceResponse_InitialResponse:
s := proto.Size(x.InitialResponse)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *LoadBalanceResponse_ServerList:
s := proto.Size(x.ServerList)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type InitialLoadBalanceResponse struct {
// This is an application layer redirect that indicates the client should use
// the specified server for load balancing. When this field is non-empty in
// the response, the client should open a separate connection to the
// load_balancer_delegate and call the BalanceLoad method. Its length should
// be less than 64 bytes.
LoadBalancerDelegate string `protobuf:"bytes,1,opt,name=load_balancer_delegate,json=loadBalancerDelegate,proto3" json:"load_balancer_delegate,omitempty"`
// This interval defines how often the client should send the client stats
// to the load balancer. Stats should only be reported when the duration is
// positive.
ClientStatsReportInterval *duration.Duration `protobuf:"bytes,2,opt,name=client_stats_report_interval,json=clientStatsReportInterval,proto3" json:"client_stats_report_interval,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *InitialLoadBalanceResponse) Reset() { *m = InitialLoadBalanceResponse{} }
func (m *InitialLoadBalanceResponse) String() string { return proto.CompactTextString(m) }
func (*InitialLoadBalanceResponse) ProtoMessage() {}
func (*InitialLoadBalanceResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{5}
}
func (m *InitialLoadBalanceResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_InitialLoadBalanceResponse.Unmarshal(m, b)
}
func (m *InitialLoadBalanceResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_InitialLoadBalanceResponse.Marshal(b, m, deterministic)
}
func (dst *InitialLoadBalanceResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_InitialLoadBalanceResponse.Merge(dst, src)
}
func (m *InitialLoadBalanceResponse) XXX_Size() int {
return xxx_messageInfo_InitialLoadBalanceResponse.Size(m)
}
func (m *InitialLoadBalanceResponse) XXX_DiscardUnknown() {
xxx_messageInfo_InitialLoadBalanceResponse.DiscardUnknown(m)
}
var xxx_messageInfo_InitialLoadBalanceResponse proto.InternalMessageInfo
func (m *InitialLoadBalanceResponse) GetLoadBalancerDelegate() string {
if m != nil {
return m.LoadBalancerDelegate
}
return ""
}
func (m *InitialLoadBalanceResponse) GetClientStatsReportInterval() *duration.Duration {
if m != nil {
return m.ClientStatsReportInterval
}
return nil
}
type ServerList struct {
// Contains a list of servers selected by the load balancer. The list will
// be updated when server resolutions change or as needed to balance load
// across more servers. The client should consume the server list in order
// unless instructed otherwise via the client_config.
Servers []*Server `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerList) Reset() { *m = ServerList{} }
func (m *ServerList) String() string { return proto.CompactTextString(m) }
func (*ServerList) ProtoMessage() {}
func (*ServerList) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{6}
}
func (m *ServerList) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerList.Unmarshal(m, b)
}
func (m *ServerList) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerList.Marshal(b, m, deterministic)
}
func (dst *ServerList) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerList.Merge(dst, src)
}
func (m *ServerList) XXX_Size() int {
return xxx_messageInfo_ServerList.Size(m)
}
func (m *ServerList) XXX_DiscardUnknown() {
xxx_messageInfo_ServerList.DiscardUnknown(m)
}
var xxx_messageInfo_ServerList proto.InternalMessageInfo
func (m *ServerList) GetServers() []*Server {
if m != nil {
return m.Servers
}
return nil
}
// Contains server information. When the drop field is not true, use the other
// fields.
type Server struct {
// A resolved address for the server, serialized in network-byte-order. It may
// either be an IPv4 or IPv6 address.
IpAddress []byte `protobuf:"bytes,1,opt,name=ip_address,json=ipAddress,proto3" json:"ip_address,omitempty"`
// A resolved port number for the server.
Port int32 `protobuf:"varint,2,opt,name=port,proto3" json:"port,omitempty"`
// An opaque but printable token for load reporting. The client must include
// the token of the picked server into the initial metadata when it starts a
// call to that server. The token is used by the server to verify the request
// and to allow the server to report load to the gRPC LB system. The token is
// also used in client stats for reporting dropped calls.
//
// Its length can be variable but must be less than 50 bytes.
LoadBalanceToken string `protobuf:"bytes,3,opt,name=load_balance_token,json=loadBalanceToken,proto3" json:"load_balance_token,omitempty"`
// Indicates whether this particular request should be dropped by the client.
// If the request is dropped, there will be a corresponding entry in
// ClientStats.calls_finished_with_drop.
Drop bool `protobuf:"varint,4,opt,name=drop,proto3" json:"drop,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Server) Reset() { *m = Server{} }
func (m *Server) String() string { return proto.CompactTextString(m) }
func (*Server) ProtoMessage() {}
func (*Server) Descriptor() ([]byte, []int) {
return fileDescriptor_load_balancer_12026aec3f0251ba, []int{7}
}
func (m *Server) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Server.Unmarshal(m, b)
}
func (m *Server) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Server.Marshal(b, m, deterministic)
}
func (dst *Server) XXX_Merge(src proto.Message) {
xxx_messageInfo_Server.Merge(dst, src)
}
func (m *Server) XXX_Size() int {
return xxx_messageInfo_Server.Size(m)
}
func (m *Server) XXX_DiscardUnknown() {
xxx_messageInfo_Server.DiscardUnknown(m)
}
var xxx_messageInfo_Server proto.InternalMessageInfo
func (m *Server) GetIpAddress() []byte {
if m != nil {
return m.IpAddress
}
return nil
}
func (m *Server) GetPort() int32 {
if m != nil {
return m.Port
}
return 0
}
func (m *Server) GetLoadBalanceToken() string {
if m != nil {
return m.LoadBalanceToken
}
return ""
}
func (m *Server) GetDrop() bool {
if m != nil {
return m.Drop
}
return false
}
func init() {
proto.RegisterType((*LoadBalanceRequest)(nil), "grpc.lb.v1.LoadBalanceRequest")
proto.RegisterType((*InitialLoadBalanceRequest)(nil), "grpc.lb.v1.InitialLoadBalanceRequest")
proto.RegisterType((*ClientStatsPerToken)(nil), "grpc.lb.v1.ClientStatsPerToken")
proto.RegisterType((*ClientStats)(nil), "grpc.lb.v1.ClientStats")
proto.RegisterType((*LoadBalanceResponse)(nil), "grpc.lb.v1.LoadBalanceResponse")
proto.RegisterType((*InitialLoadBalanceResponse)(nil), "grpc.lb.v1.InitialLoadBalanceResponse")
proto.RegisterType((*ServerList)(nil), "grpc.lb.v1.ServerList")
proto.RegisterType((*Server)(nil), "grpc.lb.v1.Server")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// LoadBalancerClient is the client API for LoadBalancer service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type LoadBalancerClient interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error)
}
type loadBalancerClient struct {
cc *grpc.ClientConn
}
func NewLoadBalancerClient(cc *grpc.ClientConn) LoadBalancerClient {
return &loadBalancerClient{cc}
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error) {
stream, err := c.cc.NewStream(ctx, &_LoadBalancer_serviceDesc.Streams[0], "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
x := &loadBalancerBalanceLoadClient{stream}
return x, nil
}
type LoadBalancer_BalanceLoadClient interface {
Send(*LoadBalanceRequest) error
Recv() (*LoadBalanceResponse, error)
grpc.ClientStream
}
type loadBalancerBalanceLoadClient struct {
grpc.ClientStream
}
func (x *loadBalancerBalanceLoadClient) Send(m *LoadBalanceRequest) error {
return x.ClientStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadClient) Recv() (*LoadBalanceResponse, error) {
m := new(LoadBalanceResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// LoadBalancerServer is the server API for LoadBalancer service.
type LoadBalancerServer interface {
// Bidirectional rpc to get a list of servers.
BalanceLoad(LoadBalancer_BalanceLoadServer) error
}
func RegisterLoadBalancerServer(s *grpc.Server, srv LoadBalancerServer) {
s.RegisterService(&_LoadBalancer_serviceDesc, srv)
}
func _LoadBalancer_BalanceLoad_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(LoadBalancerServer).BalanceLoad(&loadBalancerBalanceLoadServer{stream})
}
type LoadBalancer_BalanceLoadServer interface {
Send(*LoadBalanceResponse) error
Recv() (*LoadBalanceRequest, error)
grpc.ServerStream
}
type loadBalancerBalanceLoadServer struct {
grpc.ServerStream
}
func (x *loadBalancerBalanceLoadServer) Send(m *LoadBalanceResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *loadBalancerBalanceLoadServer) Recv() (*LoadBalanceRequest, error) {
m := new(LoadBalanceRequest)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
var _LoadBalancer_serviceDesc = grpc.ServiceDesc{
ServiceName: "grpc.lb.v1.LoadBalancer",
HandlerType: (*LoadBalancerServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "BalanceLoad",
Handler: _LoadBalancer_BalanceLoad_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "grpc/lb/v1/load_balancer.proto",
}
func init() {
proto.RegisterFile("grpc/lb/v1/load_balancer.proto", fileDescriptor_load_balancer_12026aec3f0251ba)
}
var fileDescriptor_load_balancer_12026aec3f0251ba = []byte{
// 752 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xdd, 0x6e, 0x23, 0x35,
0x14, 0xee, 0x90, 0x69, 0x36, 0x39, 0x29, 0x34, 0xeb, 0x85, 0x65, 0x92, 0xdd, 0x6d, 0x4b, 0x24,
0x56, 0x11, 0x2a, 0x13, 0x52, 0xb8, 0x00, 0x89, 0x0b, 0x48, 0xab, 0x2a, 0x2d, 0xbd, 0x88, 0x9c,
0x4a, 0x45, 0x95, 0x90, 0x99, 0xc9, 0xb8, 0xa9, 0x55, 0xc7, 0x1e, 0x3c, 0x4e, 0x2a, 0xae, 0x79,
0x1f, 0xc4, 0x2b, 0x20, 0x5e, 0x0c, 0x8d, 0xed, 0x49, 0xa6, 0x49, 0xa3, 0xbd, 0xca, 0xf8, 0x9c,
0xcf, 0xdf, 0xf9, 0xfd, 0x1c, 0x38, 0x98, 0xaa, 0x74, 0xd2, 0xe3, 0x71, 0x6f, 0xd1, 0xef, 0x71,
0x19, 0x25, 0x24, 0x8e, 0x78, 0x24, 0x26, 0x54, 0x85, 0xa9, 0x92, 0x5a, 0x22, 0xc8, 0xfd, 0x21,
0x8f, 0xc3, 0x45, 0xbf, 0x7d, 0x30, 0x95, 0x72, 0xca, 0x69, 0xcf, 0x78, 0xe2, 0xf9, 0x5d, 0x2f,
0x99, 0xab, 0x48, 0x33, 0x29, 0x2c, 0xb6, 0x7d, 0xb8, 0xee, 0xd7, 0x6c, 0x46, 0x33, 0x1d, 0xcd,
0x52, 0x0b, 0xe8, 0xfc, 0xeb, 0x01, 0xba, 0x92, 0x51, 0x32, 0xb0, 0x31, 0x30, 0xfd, 0x63, 0x4e,
0x33, 0x8d, 0x46, 0xb0, 0xcf, 0x04, 0xd3, 0x2c, 0xe2, 0x44, 0x59, 0x53, 0xe0, 0x1d, 0x79, 0xdd,
0xc6, 0xc9, 0x97, 0xe1, 0x2a, 0x7a, 0x78, 0x61, 0x21, 0x9b, 0xf7, 0x87, 0x3b, 0xf8, 0x13, 0x77,
0xbf, 0x60, 0xfc, 0x11, 0xf6, 0x26, 0x9c, 0x51, 0xa1, 0x49, 0xa6, 0x23, 0x9d, 0x05, 0x1f, 0x19,
0xba, 0xcf, 0xcb, 0x74, 0xa7, 0xc6, 0x3f, 0xce, 0xdd, 0xc3, 0x1d, 0xdc, 0x98, 0xac, 0x8e, 0x83,
0x37, 0xd0, 0x2a, 0xb7, 0xa2, 0x48, 0x8a, 0xe8, 0x3f, 0x53, 0xda, 0xe9, 0x41, 0x6b, 0x6b, 0x26,
0x08, 0x81, 0x2f, 0xa2, 0x19, 0x35, 0xe9, 0xd7, 0xb1, 0xf9, 0xee, 0xfc, 0x0e, 0xaf, 0x4a, 0xb1,
0x46, 0x54, 0x5d, 0xcb, 0x07, 0x2a, 0xd0, 0x31, 0xa0, 0x27, 0x41, 0x74, 0x6e, 0x75, 0x17, 0x9b,
0x7c, 0x45, 0x6d, 0xd1, 0x6f, 0xa0, 0x2e, 0xe6, 0x33, 0x32, 0x89, 0x38, 0xb7, 0xd5, 0x54, 0x70,
0x4d, 0xcc, 0x67, 0xa7, 0xf9, 0xb9, 0xf3, 0x4f, 0x05, 0x1a, 0xa5, 0x10, 0xe8, 0x7b, 0xa8, 0x2f,
0x3b, 0xef, 0x3a, 0xd9, 0x0e, 0xed, 0x6c, 0xc2, 0x62, 0x36, 0xe1, 0x75, 0x81, 0xc0, 0x2b, 0x30,
0xfa, 0x0a, 0x5e, 0x2e, 0xc3, 0xe4, 0xad, 0x53, 0x9a, 0x26, 0x2e, 0xdc, 0x7e, 0x11, 0x6e, 0x6c,
0xcd, 0x79, 0x01, 0x2b, 0xec, 0x1d, 0x13, 0x2c, 0xbb, 0xa7, 0x49, 0x50, 0x31, 0xe0, 0x66, 0x01,
0x3e, 0x77, 0x76, 0xf4, 0x1b, 0x7c, 0xbd, 0x89, 0x26, 0x8f, 0x4c, 0xdf, 0x13, 0x37, 0xa9, 0xbb,
0x88, 0x71, 0x9a, 0x10, 0x2d, 0x49, 0x46, 0x45, 0x12, 0x54, 0x0d, 0xd1, 0xfb, 0x75, 0xa2, 0x1b,
0xa6, 0xef, 0x6d, 0xad, 0xe7, 0x06, 0x7f, 0x2d, 0xc7, 0x54, 0x24, 0x68, 0x08, 0x5f, 0x3c, 0x43,
0xff, 0x20, 0xe4, 0xa3, 0x20, 0x8a, 0x4e, 0x28, 0x5b, 0xd0, 0x24, 0x78, 0x61, 0x28, 0xdf, 0xad,
0x53, 0xfe, 0x92, 0xa3, 0xb0, 0x03, 0xa1, 0x5f, 0x21, 0x78, 0x2e, 0xc9, 0x44, 0xc9, 0x34, 0xa8,
0x1d, 0x55, 0xba, 0x8d, 0x93, 0xc3, 0x2d, 0x6b, 0x54, 0x8c, 0x16, 0x7f, 0x36, 0x59, 0xcf, 0xf8,
0x4c, 0xc9, 0xf4, 0xd2, 0xaf, 0xf9, 0xcd, 0xdd, 0x4b, 0xbf, 0xb6, 0xdb, 0xac, 0x76, 0xfe, 0xf3,
0xe0, 0xd5, 0x93, 0xfd, 0xc9, 0x52, 0x29, 0x32, 0x8a, 0xc6, 0xd0, 0x5c, 0x49, 0xc1, 0xda, 0xdc,
0x04, 0xdf, 0x7f, 0x48, 0x0b, 0x16, 0x3d, 0xdc, 0xc1, 0xfb, 0x4b, 0x31, 0x38, 0xd2, 0x1f, 0xa0,
0x91, 0x51, 0xb5, 0xa0, 0x8a, 0x70, 0x96, 0x69, 0x27, 0x86, 0xd7, 0x65, 0xbe, 0xb1, 0x71, 0x5f,
0x31, 0x23, 0x26, 0xc8, 0x96, 0xa7, 0xc1, 0x5b, 0x68, 0xaf, 0x49, 0xc1, 0x72, 0x5a, 0x2d, 0xfc,
0xed, 0x41, 0x7b, 0x7b, 0x2a, 0xe8, 0x3b, 0x78, 0xfd, 0xe4, 0x49, 0x21, 0x09, 0xe5, 0x74, 0x1a,
0xe9, 0x42, 0x1f, 0x9f, 0x96, 0xd6, 0x5c, 0x9d, 0x39, 0x1f, 0xba, 0x85, 0xb7, 0x65, 0xed, 0x12,
0x45, 0x53, 0xa9, 0x34, 0x61, 0x42, 0x53, 0xb5, 0x88, 0xb8, 0x4b, 0xbf, 0xb5, 0xb1, 0xd0, 0x67,
0xee, 0x31, 0xc2, 0xad, 0x92, 0x96, 0xb1, 0xb9, 0x7c, 0xe1, 0xee, 0x76, 0x7e, 0x02, 0x58, 0x95,
0x8a, 0x8e, 0xe1, 0x85, 0x2d, 0x35, 0x0b, 0x3c, 0x33, 0x59, 0xb4, 0xd9, 0x13, 0x5c, 0x40, 0x2e,
0xfd, 0x5a, 0xa5, 0xe9, 0x77, 0xfe, 0xf2, 0xa0, 0x6a, 0x3d, 0xe8, 0x1d, 0x00, 0x4b, 0x49, 0x94,
0x24, 0x8a, 0x66, 0x99, 0x29, 0x69, 0x0f, 0xd7, 0x59, 0xfa, 0xb3, 0x35, 0xe4, 0x6f, 0x41, 0x1e,
0xdb, 0xe4, 0xbb, 0x8b, 0xcd, 0xf7, 0x16, 0xd1, 0x57, 0xb6, 0x88, 0x1e, 0x81, 0x6f, 0xd6, 0xce,
0x3f, 0xf2, 0xba, 0x35, 0x6c, 0xbe, 0xed, 0xfa, 0x9c, 0xc4, 0xb0, 0x57, 0x6a, 0xb8, 0x42, 0x18,
0x1a, 0xee, 0x3b, 0x37, 0xa3, 0x83, 0x72, 0x1d, 0x9b, 0xcf, 0x54, 0xfb, 0x70, 0xab, 0xdf, 0x4e,
0xae, 0xeb, 0x7d, 0xe3, 0x0d, 0x6e, 0xe0, 0x63, 0x26, 0x4b, 0xc0, 0xc1, 0xcb, 0x72, 0xc8, 0x51,
0xde, 0xf6, 0x91, 0x77, 0xdb, 0x77, 0x63, 0x98, 0x4a, 0x1e, 0x89, 0x69, 0x28, 0xd5, 0xb4, 0x67,
0xfe, 0x51, 0x8a, 0x99, 0x9b, 0x13, 0x8f, 0xcd, 0x0f, 0xe1, 0x31, 0x59, 0xf4, 0xe3, 0xaa, 0x19,
0xd9, 0xb7, 0xff, 0x07, 0x00, 0x00, 0xff, 0xff, 0x81, 0x14, 0xee, 0xd1, 0x7b, 0x06, 0x00, 0x00,
}

View File

@@ -16,19 +16,31 @@
*
*/
package grpc
//go:generate ./regenerate.sh
// Package grpclb defines a grpclb balancer.
//
// To install grpclb balancer, import this package as:
// import _ "google.golang.org/grpc/balancer/grpclb"
package grpclb
import (
"errors"
"strconv"
"strings"
"sync"
"time"
durationpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/resolver"
)
@@ -38,7 +50,21 @@ const (
grpclbName = "grpclb"
)
func convertDuration(d *lbpb.Duration) time.Duration {
var (
// defaultBackoffConfig configures the backoff strategy that's used when the
// init handshake in the RPC is unsuccessful. It's not for the clientconn
// reconnect backoff.
//
// It has the same value as the default grpc.DefaultBackoffConfig.
//
// TODO: make backoff configurable.
defaultBackoffConfig = backoff.Exponential{
MaxDelay: 120 * time.Second,
}
errServerTerminatedConnection = errors.New("grpclb: failed to recv server list: server terminated connection")
)
func convertDuration(d *durationpb.Duration) time.Duration {
if d == nil {
return 0
}
@@ -49,16 +75,16 @@ func convertDuration(d *lbpb.Duration) time.Duration {
// Mostly copied from generated pb.go file.
// To avoid circular dependency.
type loadBalancerClient struct {
cc *ClientConn
cc *grpc.ClientConn
}
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption) (*balanceLoadClientStream, error) {
desc := &StreamDesc{
func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (*balanceLoadClientStream, error) {
desc := &grpc.StreamDesc{
StreamName: "BalanceLoad",
ServerStreams: true,
ClientStreams: true,
}
stream, err := NewClientStream(ctx, desc, c.cc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
stream, err := c.cc.NewStream(ctx, desc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...)
if err != nil {
return nil, err
}
@@ -67,7 +93,7 @@ func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...CallOption
}
type balanceLoadClientStream struct {
ClientStream
grpc.ClientStream
}
func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error {
@@ -88,16 +114,16 @@ func init() {
// newLBBuilder creates a builder for grpclb.
func newLBBuilder() balancer.Builder {
return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
return newLBBuilderWithFallbackTimeout(defaultFallbackTimeout)
}
// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given
// newLBBuilderWithFallbackTimeout creates a grpclb builder with the given
// fallbackTimeout. If no response is received from the remote balancer within
// fallbackTimeout, the backend addresses from the resolved address list will be
// used.
//
// Only call this function when a non-default fallback timeout is needed.
func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
func newLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder {
return &lbBuilder{
fallbackTimeout: fallbackTimeout,
}
@@ -127,27 +153,50 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal
}
lb := &lbBalancer{
cc: cc,
cc: newLBCacheClientConn(cc),
target: target,
opt: opt,
fallbackTimeout: b.fallbackTimeout,
doneCh: make(chan struct{}),
manualResolver: r,
csEvltr: &connectivityStateEvaluator{},
csEvltr: &balancer.ConnectivityStateEvaluator{},
subConns: make(map[resolver.Address]balancer.SubConn),
scStates: make(map[balancer.SubConn]connectivity.State),
picker: &errPicker{err: balancer.ErrNoSubConnAvailable},
clientStats: &rpcStats{},
clientStats: newRPCStats(),
backoff: defaultBackoffConfig, // TODO: make backoff configurable.
}
var err error
if opt.CredsBundle != nil {
lb.grpclbClientConnCreds, err = opt.CredsBundle.NewWithMode(internal.CredsBundleModeBalancer)
if err != nil {
grpclog.Warningf("lbBalancer: client connection creds NewWithMode failed: %v", err)
}
lb.grpclbBackendCreds, err = opt.CredsBundle.NewWithMode(internal.CredsBundleModeBackendFromBalancer)
if err != nil {
grpclog.Warningf("lbBalancer: backend creds NewWithMode failed: %v", err)
}
}
return lb
}
type lbBalancer struct {
cc balancer.ClientConn
target string
opt balancer.BuildOptions
cc *lbCacheClientConn
target string
opt balancer.BuildOptions
// grpclbClientConnCreds is the creds bundle to be used to connect to grpclb
// servers. If it's nil, use the TransportCredentials from BuildOptions
// instead.
grpclbClientConnCreds credentials.Bundle
// grpclbBackendCreds is the creds bundle to be used for addresses that are
// returned by grpclb server. If it's nil, don't set anything when creating
// SubConns.
grpclbBackendCreds credentials.Bundle
fallbackTimeout time.Duration
doneCh chan struct{}
@@ -156,7 +205,9 @@ type lbBalancer struct {
// send to remote LB ClientConn through this resolver.
manualResolver *lbManualResolver
// The ClientConn to talk to the remote balancer.
ccRemoteLB *ClientConn
ccRemoteLB *grpc.ClientConn
// backoff for calling remote balancer.
backoff backoff.Strategy
// Support client side load reporting. Each picker gets a reference to this,
// and will update its content.
@@ -173,7 +224,7 @@ type lbBalancer struct {
// but with only READY SCs will be gerenated.
backendAddrs []resolver.Address
// Roundrobin functionalities.
csEvltr *connectivityStateEvaluator
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn.
scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns.
@@ -220,7 +271,6 @@ func (lb *lbBalancer) regeneratePicker() {
subConns: readySCs,
stats: lb.clientStats,
}
return
}
func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
@@ -244,7 +294,7 @@ func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivi
}
oldAggrState := lb.state
lb.state = lb.csEvltr.recordTransition(oldS, s)
lb.state = lb.csEvltr.RecordTransition(oldS, s)
// Regenerate picker when one of the following happens:
// - this sc became ready from not-ready
@@ -257,7 +307,6 @@ func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivi
}
lb.cc.UpdateBalancerState(lb.state, lb.picker)
return
}
// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use
@@ -277,7 +326,7 @@ func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) {
return
}
lb.fallbackTimerExpired = true
lb.refreshSubConns(lb.resolvedBackendAddrs)
lb.refreshSubConns(lb.resolvedBackendAddrs, false)
lb.mu.Unlock()
}
@@ -324,7 +373,7 @@ func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
// This means we received a new list of resolved backends, and we are
// still in fallback mode. Need to update the list of backends we are
// using to the new list of backends.
lb.refreshSubConns(lb.resolvedBackendAddrs)
lb.refreshSubConns(lb.resolvedBackendAddrs, false)
}
lb.mu.Unlock()
}
@@ -339,4 +388,5 @@ func (lb *lbBalancer) Close() {
if lb.ccRemoteLB != nil {
lb.ccRemoteLB.Close()
}
lb.cc.close()
}

View File

@@ -16,7 +16,7 @@
*
*/
package grpc
package grpclb
import (
"sync"
@@ -24,55 +24,70 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/status"
)
// rpcStats is same as lbmpb.ClientStats, except that numCallsDropped is a map
// instead of a slice.
type rpcStats struct {
NumCallsStarted int64
NumCallsFinished int64
NumCallsFinishedWithDropForRateLimiting int64
NumCallsFinishedWithDropForLoadBalancing int64
NumCallsFinishedWithClientFailedToSend int64
NumCallsFinishedKnownReceived int64
// Only access the following fields atomically.
numCallsStarted int64
numCallsFinished int64
numCallsFinishedWithClientFailedToSend int64
numCallsFinishedKnownReceived int64
mu sync.Mutex
// map load_balance_token -> num_calls_dropped
numCallsDropped map[string]int64
}
func newRPCStats() *rpcStats {
return &rpcStats{
numCallsDropped: make(map[string]int64),
}
}
// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats.
func (s *rpcStats) toClientStats() *lbpb.ClientStats {
stats := &lbpb.ClientStats{
NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0),
NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0),
NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0),
NumCallsStarted: atomic.SwapInt64(&s.numCallsStarted, 0),
NumCallsFinished: atomic.SwapInt64(&s.numCallsFinished, 0),
NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.numCallsFinishedWithClientFailedToSend, 0),
NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.numCallsFinishedKnownReceived, 0),
}
s.mu.Lock()
dropped := s.numCallsDropped
s.numCallsDropped = make(map[string]int64)
s.mu.Unlock()
for token, count := range dropped {
stats.CallsFinishedWithDrop = append(stats.CallsFinishedWithDrop, &lbpb.ClientStatsPerToken{
LoadBalanceToken: token,
NumCalls: count,
})
}
return stats
}
func (s *rpcStats) dropForRateLimiting() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
}
func (s *rpcStats) dropForLoadBalancing() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
func (s *rpcStats) drop(token string) {
atomic.AddInt64(&s.numCallsStarted, 1)
s.mu.Lock()
s.numCallsDropped[token]++
s.mu.Unlock()
atomic.AddInt64(&s.numCallsFinished, 1)
}
func (s *rpcStats) failedToSend() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
atomic.AddInt64(&s.numCallsStarted, 1)
atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, 1)
atomic.AddInt64(&s.numCallsFinished, 1)
}
func (s *rpcStats) knownReceived() {
atomic.AddInt64(&s.NumCallsStarted, 1)
atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.NumCallsFinished, 1)
atomic.AddInt64(&s.numCallsStarted, 1)
atomic.AddInt64(&s.numCallsFinishedKnownReceived, 1)
atomic.AddInt64(&s.numCallsFinished, 1)
}
type errPicker struct {
@@ -131,12 +146,8 @@ func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balance
p.serverListNext = (p.serverListNext + 1) % len(p.serverList)
// If it's a drop, return an error and fail the RPC.
if s.DropForRateLimiting {
p.stats.dropForRateLimiting()
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}
if s.DropForLoadBalancing {
p.stats.dropForLoadBalancing()
if s.Drop {
p.stats.drop(s.LoadBalanceToken)
return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
}

View File

@@ -16,19 +16,24 @@
*
*/
package grpc
package grpclb
import (
"fmt"
"io"
"net"
"reflect"
"time"
timestamppb "github.com/golang/protobuf/ptypes/timestamp"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/connectivity"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
)
@@ -52,8 +57,8 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
lb.fullServerList = l.Servers
var backendAddrs []resolver.Address
for _, s := range l.Servers {
if s.DropForLoadBalancing || s.DropForRateLimiting {
for i, s := range l.Servers {
if s.Drop {
continue
}
@@ -69,27 +74,34 @@ func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
Addr: fmt.Sprintf("%s:%d", ipStr, s.Port),
Metadata: &md,
}
grpclog.Infof("lbBalancer: server list entry[%d]: ipStr:|%s|, port:|%d|, load balancer token:|%v|",
i, ipStr, s.Port, s.LoadBalanceToken)
backendAddrs = append(backendAddrs, addr)
}
// Call refreshSubConns to create/remove SubConns.
backendsUpdated := lb.refreshSubConns(backendAddrs)
// If no backend was updated, no SubConn will be newed/removed. But since
// the full serverList was different, there might be updates in drops or
// pick weights(different number of duplicates). We need to update picker
// with the fulllist.
if !backendsUpdated {
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
lb.refreshSubConns(backendAddrs, true)
// Regenerate and update picker no matter if there's update on backends (if
// any SubConn will be newed/removed). Because since the full serverList was
// different, there might be updates in drops or pick weights(different
// number of duplicates). We need to update picker with the fulllist.
//
// Now with cache, even if SubConn was newed/removed, there might be no
// state changes.
lb.regeneratePicker()
lb.cc.UpdateBalancerState(lb.state, lb.picker)
}
// refreshSubConns creates/removes SubConns with backendAddrs. It returns a bool
// indicating whether the backendAddrs are different from the cached
// backendAddrs (whether any SubConn was newed/removed).
// Caller must hold lb.mu.
func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool {
func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address, fromGRPCLBServer bool) bool {
opts := balancer.NewSubConnOptions{}
if fromGRPCLBServer {
opts.CredsBundle = lb.grpclbBackendCreds
}
lb.backendAddrs = nil
var backendsUpdated bool
// addrsSet is the set converted from backendAddrs, it's used to quick
@@ -106,13 +118,17 @@ func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool {
backendsUpdated = true
// Use addrWithMD to create the SubConn.
sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, opts)
if err != nil {
grpclog.Warningf("roundrobinBalancer: failed to create new SubConn: %v", err)
continue
}
lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
lb.scStates[sc] = connectivity.Idle
if _, ok := lb.scStates[sc]; !ok {
// Only set state of new sc to IDLE. The state could already be
// READY for cached SubConns.
lb.scStates[sc] = connectivity.Idle
}
sc.Connect()
}
}
@@ -136,6 +152,9 @@ func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
for {
reply, err := s.Recv()
if err != nil {
if err == io.EOF {
return errServerTerminatedConnection
}
return fmt.Errorf("grpclb: failed to recv server list: %v", err)
}
if serverList := reply.GetServerList(); serverList != nil {
@@ -155,7 +174,7 @@ func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.D
}
stats := lb.clientStats.toClientStats()
t := time.Now()
stats.Timestamp = &lbpb.Timestamp{
stats.Timestamp = &timestamppb.Timestamp{
Seconds: t.Unix(),
Nanos: int32(t.Nanosecond()),
}
@@ -168,13 +187,14 @@ func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.D
}
}
}
func (lb *lbBalancer) callRemoteBalancer() error {
func (lb *lbBalancer) callRemoteBalancer() (backoff bool, _ error) {
lbClient := &loadBalancerClient{cc: lb.ccRemoteLB}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
stream, err := lbClient.BalanceLoad(ctx, FailFast(false))
stream, err := lbClient.BalanceLoad(ctx, grpc.FailFast(false))
if err != nil {
return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err)
}
// grpclb handshake on the stream.
@@ -186,18 +206,18 @@ func (lb *lbBalancer) callRemoteBalancer() error {
},
}
if err := stream.Send(initReq); err != nil {
return fmt.Errorf("grpclb: failed to send init request: %v", err)
return true, fmt.Errorf("grpclb: failed to send init request: %v", err)
}
reply, err := stream.Recv()
if err != nil {
return fmt.Errorf("grpclb: failed to recv init response: %v", err)
return true, fmt.Errorf("grpclb: failed to recv init response: %v", err)
}
initResp := reply.GetInitialResponse()
if initResp == nil {
return fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
return true, fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
}
if initResp.LoadBalancerDelegate != "" {
return fmt.Errorf("grpclb: Delegation is not supported")
return true, fmt.Errorf("grpclb: Delegation is not supported")
}
go func() {
@@ -205,47 +225,77 @@ func (lb *lbBalancer) callRemoteBalancer() error {
lb.sendLoadReport(stream, d)
}
}()
return lb.readServerList(stream)
// No backoff if init req/resp handshake was successful.
return false, lb.readServerList(stream)
}
func (lb *lbBalancer) watchRemoteBalancer() {
var retryCount int
for {
err := lb.callRemoteBalancer()
doBackoff, err := lb.callRemoteBalancer()
select {
case <-lb.doneCh:
return
default:
if err != nil {
grpclog.Error(err)
if err == errServerTerminatedConnection {
grpclog.Info(err)
} else {
grpclog.Warning(err)
}
}
}
if !doBackoff {
retryCount = 0
continue
}
timer := time.NewTimer(lb.backoff.Backoff(retryCount))
select {
case <-timer.C:
case <-lb.doneCh:
timer.Stop()
return
}
retryCount++
}
}
func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
var dopts []DialOption
var dopts []grpc.DialOption
if creds := lb.opt.DialCreds; creds != nil {
if err := creds.OverrideServerName(remoteLBName); err == nil {
dopts = append(dopts, WithTransportCredentials(creds))
dopts = append(dopts, grpc.WithTransportCredentials(creds))
} else {
grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err)
dopts = append(dopts, WithInsecure())
dopts = append(dopts, grpc.WithInsecure())
}
} else if bundle := lb.grpclbClientConnCreds; bundle != nil {
dopts = append(dopts, grpc.WithCredentialsBundle(bundle))
} else {
dopts = append(dopts, WithInsecure())
dopts = append(dopts, grpc.WithInsecure())
}
if lb.opt.Dialer != nil {
// WithDialer takes a different type of function, so we instead use a
// special DialOption here.
dopts = append(dopts, withContextDialer(lb.opt.Dialer))
wcd := internal.WithContextDialer.(func(func(context.Context, string) (net.Conn, error)) grpc.DialOption)
dopts = append(dopts, wcd(lb.opt.Dialer))
}
// Explicitly set pickfirst as the balancer.
dopts = append(dopts, WithBalancerName(PickFirstBalancerName))
dopts = append(dopts, withResolverBuilder(lb.manualResolver))
// Dial using manualResolver.Scheme, which is a random scheme generated
// when init grpclb. The target name is not important.
cc, err := Dial("grpclb:///grpclb.server", dopts...)
dopts = append(dopts, grpc.WithBalancerName(grpc.PickFirstBalancerName))
wrb := internal.WithResolverBuilder.(func(resolver.Builder) grpc.DialOption)
dopts = append(dopts, wrb(lb.manualResolver))
if channelz.IsOn() {
dopts = append(dopts, grpc.WithChannelzParentID(lb.opt.ChannelzParentID))
}
// DialContext using manualResolver.Scheme, which is a random scheme
// generated when init grpclb. The target scheme here is not important.
//
// The grpc dial target will be used by the creds (ALTS) as the authority,
// so it has to be set to remoteLBName that comes from resolver.
cc, err := grpc.DialContext(context.Background(), remoteLBName, dopts...)
if err != nil {
grpclog.Fatalf("failed to dial: %v", err)
}

View File

@@ -16,11 +16,7 @@
*
*/
//go:generate protoc --go_out=plugins=:$GOPATH grpc_lb_v1/messages/messages.proto
//go:generate protoc --go_out=plugins=grpc:$GOPATH grpc_lb_v1/service/service.proto
// Package grpclb_test is currently used only for grpclb testing.
package grpclb_test
package grpclb
import (
"errors"
@@ -30,27 +26,26 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/golang/protobuf/proto"
durationpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
"google.golang.org/grpc/test/leakcheck"
_ "google.golang.org/grpc/grpclog/glogger"
)
var (
@@ -127,18 +122,75 @@ func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
// merge merges the new client stats into current stats.
//
// It's a test-only method. rpcStats is defined in grpclb_picker.
func (s *rpcStats) merge(cs *lbpb.ClientStats) {
atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted)
atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished)
atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend)
atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived)
s.mu.Lock()
for _, perToken := range cs.CallsFinishedWithDrop {
s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
}
s.mu.Unlock()
}
func mapsEqual(a, b map[string]int64) bool {
if len(a) != len(b) {
return false
}
for k, v1 := range a {
if v2, ok := b[k]; !ok || v1 != v2 {
return false
}
}
return true
}
func atomicEqual(a, b *int64) bool {
return atomic.LoadInt64(a) == atomic.LoadInt64(b)
}
// equal compares two rpcStats.
//
// It's a test-only method. rpcStats is defined in grpclb_picker.
func (s *rpcStats) equal(o *rpcStats) bool {
if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) {
return false
}
if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) {
return false
}
if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) {
return false
}
if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) {
return false
}
s.mu.Lock()
defer s.mu.Unlock()
o.mu.Lock()
defer o.mu.Unlock()
if !mapsEqual(s.numCallsDropped, o.numCallsDropped) {
return false
}
return true
}
type remoteBalancer struct {
sls chan *lbmpb.ServerList
sls chan *lbpb.ServerList
statsDura time.Duration
done chan struct{}
mu sync.Mutex
stats lbmpb.ClientStats
stats *rpcStats
}
func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
return &remoteBalancer{
sls: make(chan *lbmpb.ServerList, 1),
done: make(chan struct{}),
sls: make(chan *lbpb.ServerList, 1),
done: make(chan struct{}),
stats: newRPCStats(),
}
}
@@ -147,7 +199,7 @@ func (b *remoteBalancer) stop() {
close(b.done)
}
func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer) error {
func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
req, err := stream.Recv()
if err != nil {
return err
@@ -156,10 +208,10 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer
if initReq.Name != beServerName {
return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
}
resp := &lbmpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_InitialResponse{
InitialResponse: &lbmpb.InitialLoadBalanceResponse{
ClientStatsReportInterval: &lbmpb.Duration{
resp := &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
InitialResponse: &lbpb.InitialLoadBalanceResponse{
ClientStatsReportInterval: &durationpb.Duration{
Seconds: int64(b.statsDura.Seconds()),
Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
},
@@ -172,25 +224,18 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer
go func() {
for {
var (
req *lbmpb.LoadBalanceRequest
req *lbpb.LoadBalanceRequest
err error
)
if req, err = stream.Recv(); err != nil {
return
}
b.mu.Lock()
b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted
b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished
b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting
b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing
b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend
b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived
b.mu.Unlock()
b.stats.merge(req.GetClientStats())
}
}()
for v := range b.sls {
resp = &lbmpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{
resp = &lbpb.LoadBalanceResponse{
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
ServerList: v,
},
}
@@ -288,12 +333,8 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er
sn: lbServerName,
}
lb = grpc.NewServer(grpc.Creds(lbCreds))
if err != nil {
err = fmt.Errorf("Failed to generate the port number %v", err)
return
}
ls = newRemoteBalancer(nil)
lbspb.RegisterLoadBalancerServer(lb, ls)
lbgrpc.RegisterLoadBalancerServer(lb, ls)
go func() {
lb.Serve(lbLis)
}()
@@ -327,14 +368,14 @@ func TestGRPCLB(t *testing.T) {
}
defer cleanup()
be := &lbmpb.Server{
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbmpb.Server
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbmpb.ServerList{
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
@@ -375,7 +416,7 @@ func TestGRPCLBWeighted(t *testing.T) {
}
defer cleanup()
beServers := []*lbmpb.Server{{
beServers := []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
@@ -411,14 +452,14 @@ func TestGRPCLBWeighted(t *testing.T) {
sequences := []string{"00101", "00011"}
for _, seq := range sequences {
var (
bes []*lbmpb.Server
bes []*lbpb.Server
p peer.Peer
result string
)
for _, s := range seq {
bes = append(bes, beServers[s-'0'])
}
tss.ls.sls <- &lbmpb.ServerList{Servers: bes}
tss.ls.sls <- &lbpb.ServerList{Servers: bes}
for i := 0; i < 1000; i++ {
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
@@ -444,14 +485,14 @@ func TestDropRequest(t *testing.T) {
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls <- &lbmpb.ServerList{
Servers: []*lbmpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: false,
tss.ls.sls <- &lbpb.ServerList{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
Drop: false,
}, {
DropForLoadBalancing: true,
Drop: true,
}},
}
creds := serverNameCheckCreds{
@@ -473,11 +514,24 @@ func TestDropRequest(t *testing.T) {
ServerName: lbServerName,
}})
// The 1st, non-fail-fast RPC should succeed. This ensures both server
// connections are made, because the first one has DropForLoadBalancing set to true.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
// Wait for the 1st, non-fail-fast RPC to succeed. This ensures both server
// connections are made, because the first one has DropForLoadBalancing set
// to true.
var i int
for i = 0; i < 1000; i++ {
if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
break
}
time.Sleep(time.Millisecond)
}
if i >= 1000 {
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
}
select {
case <-ctx.Done():
t.Fatal("timed out", ctx.Err())
default:
}
for _, failfast := range []bool{true, false} {
for i := 0; i < 3; i++ {
// Even RPCs should fail, because the 2st backend has
@@ -512,14 +566,14 @@ func TestBalancerDisconnects(t *testing.T) {
}
defer cleanup()
be := &lbmpb.Server{
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbmpb.Server
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbmpb.ServerList{
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
@@ -587,7 +641,7 @@ const grpclbCustomFallbackName = "grpclb_with_custom_fallback_timeout"
func init() {
balancer.Register(&customGRPCLBBuilder{
Builder: grpc.NewLBBuilderWithFallbackTimeout(100 * time.Millisecond),
Builder: newLBBuilderWithFallbackTimeout(100 * time.Millisecond),
name: grpclbCustomFallbackName,
})
}
@@ -613,14 +667,14 @@ func TestFallback(t *testing.T) {
standaloneBEs := startBackends(beServerName, true, beLis)
defer stopBackends(standaloneBEs)
be := &lbmpb.Server{
be := &lbpb.Server{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbmpb.Server
var bes []*lbpb.Server
bes = append(bes, be)
sl := &lbmpb.ServerList{
sl := &lbpb.ServerList{
Servers: bes,
}
tss.ls.sls <- sl
@@ -691,14 +745,14 @@ func (failPreRPCCred) RequireTransportSecurity() bool {
return false
}
func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error {
if !proto.Equal(stats, expected) {
func checkStats(stats, expected *rpcStats) error {
if !stats.equal(expected) {
return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
}
return nil
}
func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats {
func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
defer leakcheck.Check(t)
r, cleanup := manual.GenerateAndRegisterManualResolver()
@@ -709,13 +763,12 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool
t.Fatalf("failed to create new load balancer: %v", err)
}
defer cleanup()
tss.ls.sls <- &lbmpb.ServerList{
Servers: []*lbmpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropForLoadBalancing: dropForLoadBalancing,
DropForRateLimiting: dropForRateLimiting,
tss.ls.sls <- &lbpb.ServerList{
Servers: []*lbpb.Server{{
IpAddress: tss.beIPs[0],
Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
Drop: drop,
}},
}
tss.ls.statsDura = 100 * time.Millisecond
@@ -740,9 +793,7 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool
runRPCs(cc)
time.Sleep(1 * time.Second)
tss.ls.mu.Lock()
stats := tss.ls.stats
tss.ls.mu.Unlock()
return stats
}
@@ -754,7 +805,7 @@ const (
func TestGRPCLBStatsUnarySuccess(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
@@ -765,19 +816,19 @@ func TestGRPCLBStatsUnarySuccess(t *testing.T) {
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC),
NumCallsFinished: int64(countRPC),
NumCallsFinishedKnownReceived: int64(countRPC),
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedKnownReceived: int64(countRPC),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
func TestGRPCLBStatsUnaryDrop(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
@@ -792,39 +843,11 @@ func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC + c),
NumCallsFinished: int64(countRPC + c),
NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
NumCallsFinishedWithClientFailedToSend: int64(c - 1),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), dropErrDesc) {
break
}
}
}
for i := 0; i < countRPC; i++ {
testC.EmptyCall(context.Background(), &testpb.Empty{})
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC + c),
NumCallsFinished: int64(countRPC + c),
NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
NumCallsFinishedWithClientFailedToSend: int64(c - 1),
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC + c),
numCallsFinished: int64(countRPC + c),
numCallsFinishedWithClientFailedToSend: int64(c - 1),
numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
}); err != nil {
t.Fatal(err)
}
@@ -832,22 +855,22 @@ func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
}
for i := 0; i < countRPC-1; i++ {
grpc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil, cc)
cc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil)
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC),
NumCallsFinished: int64(countRPC),
NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
NumCallsFinishedKnownReceived: 1,
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
numCallsFinishedKnownReceived: 1,
}); err != nil {
t.Fatal(err)
}
@@ -855,7 +878,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
@@ -880,19 +903,19 @@ func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC),
NumCallsFinished: int64(countRPC),
NumCallsFinishedKnownReceived: int64(countRPC),
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedKnownReceived: int64(countRPC),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
func TestGRPCLBStatsStreamingDrop(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
@@ -907,39 +930,11 @@ func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC + c),
NumCallsFinished: int64(countRPC + c),
NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
NumCallsFinishedWithClientFailedToSend: int64(c - 1),
}); err != nil {
t.Fatal(err)
}
}
func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
defer leakcheck.Check(t)
c := 0
stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
for {
c++
if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
if strings.Contains(err.Error(), dropErrDesc) {
break
}
}
}
for i := 0; i < countRPC; i++ {
testC.FullDuplexCall(context.Background())
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC + c),
NumCallsFinished: int64(countRPC + c),
NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
NumCallsFinishedWithClientFailedToSend: int64(c - 1),
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC + c),
numCallsFinished: int64(countRPC + c),
numCallsFinishedWithClientFailedToSend: int64(c - 1),
numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
}); err != nil {
t.Fatal(err)
}
@@ -947,7 +942,7 @@ func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
defer leakcheck.Check(t)
stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
testC := testpb.NewTestServiceClient(cc)
// The first non-failfast RPC succeeds, all connections are up.
stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
@@ -960,15 +955,15 @@ func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
}
}
for i := 0; i < countRPC-1; i++ {
grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, failtosendURI)
cc.NewStream(context.Background(), &grpc.StreamDesc{}, failtosendURI)
}
})
if err := checkStats(&stats, &lbmpb.ClientStats{
NumCallsStarted: int64(countRPC),
NumCallsFinished: int64(countRPC),
NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
NumCallsFinishedKnownReceived: 1,
if err := checkStats(stats, &rpcStats{
numCallsStarted: int64(countRPC),
numCallsFinished: int64(countRPC),
numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
numCallsFinishedKnownReceived: 1,
}); err != nil {
t.Fatal(err)
}

View File

@@ -16,10 +16,15 @@
*
*/
package grpc
package grpclb
import (
"fmt"
"sync"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
@@ -88,3 +93,122 @@ func (r *lbManualResolver) NewAddress(addrs []resolver.Address) {
func (r *lbManualResolver) NewServiceConfig(sc string) {
r.ccr.NewServiceConfig(sc)
}
const subConnCacheTime = time.Second * 10
// lbCacheClientConn is a wrapper balancer.ClientConn with a SubConn cache.
// SubConns will be kept in cache for subConnCacheTime before being removed.
//
// Its new and remove methods are updated to do cache first.
type lbCacheClientConn struct {
cc balancer.ClientConn
timeout time.Duration
mu sync.Mutex
// subConnCache only keeps subConns that are being deleted.
subConnCache map[resolver.Address]*subConnCacheEntry
subConnToAddr map[balancer.SubConn]resolver.Address
}
type subConnCacheEntry struct {
sc balancer.SubConn
cancel func()
abortDeleting bool
}
func newLBCacheClientConn(cc balancer.ClientConn) *lbCacheClientConn {
return &lbCacheClientConn{
cc: cc,
timeout: subConnCacheTime,
subConnCache: make(map[resolver.Address]*subConnCacheEntry),
subConnToAddr: make(map[balancer.SubConn]resolver.Address),
}
}
func (ccc *lbCacheClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
if len(addrs) != 1 {
return nil, fmt.Errorf("grpclb calling NewSubConn with addrs of length %v", len(addrs))
}
addrWithoutMD := addrs[0]
addrWithoutMD.Metadata = nil
ccc.mu.Lock()
defer ccc.mu.Unlock()
if entry, ok := ccc.subConnCache[addrWithoutMD]; ok {
// If entry is in subConnCache, the SubConn was being deleted.
// cancel function will never be nil.
entry.cancel()
delete(ccc.subConnCache, addrWithoutMD)
return entry.sc, nil
}
scNew, err := ccc.cc.NewSubConn(addrs, opts)
if err != nil {
return nil, err
}
ccc.subConnToAddr[scNew] = addrWithoutMD
return scNew, nil
}
func (ccc *lbCacheClientConn) RemoveSubConn(sc balancer.SubConn) {
ccc.mu.Lock()
defer ccc.mu.Unlock()
addr, ok := ccc.subConnToAddr[sc]
if !ok {
return
}
if entry, ok := ccc.subConnCache[addr]; ok {
if entry.sc != sc {
// This could happen if NewSubConn was called multiple times for the
// same address, and those SubConns are all removed. We remove sc
// immediately here.
delete(ccc.subConnToAddr, sc)
ccc.cc.RemoveSubConn(sc)
}
return
}
entry := &subConnCacheEntry{
sc: sc,
}
ccc.subConnCache[addr] = entry
timer := time.AfterFunc(ccc.timeout, func() {
ccc.mu.Lock()
if entry.abortDeleting {
return
}
ccc.cc.RemoveSubConn(sc)
delete(ccc.subConnToAddr, sc)
delete(ccc.subConnCache, addr)
ccc.mu.Unlock()
})
entry.cancel = func() {
if !timer.Stop() {
// If stop was not successful, the timer has fired (this can only
// happen in a race). But the deleting function is blocked on ccc.mu
// because the mutex was held by the caller of this function.
//
// Set abortDeleting to true to abort the deleting function. When
// the lock is released, the deleting function will acquire the
// lock, check the value of abortDeleting and return.
entry.abortDeleting = true
}
}
}
func (ccc *lbCacheClientConn) UpdateBalancerState(s connectivity.State, p balancer.Picker) {
ccc.cc.UpdateBalancerState(s, p)
}
func (ccc *lbCacheClientConn) close() {
ccc.mu.Lock()
// Only cancel all existing timers. There's no need to remove SubConns.
for _, entry := range ccc.subConnCache {
entry.cancel()
}
ccc.mu.Unlock()
}

View File

@@ -0,0 +1,219 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpclb
import (
"fmt"
"sync"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
type mockSubConn struct {
balancer.SubConn
}
type mockClientConn struct {
balancer.ClientConn
mu sync.Mutex
subConns map[balancer.SubConn]resolver.Address
}
func newMockClientConn() *mockClientConn {
return &mockClientConn{
subConns: make(map[balancer.SubConn]resolver.Address),
}
}
func (mcc *mockClientConn) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
sc := &mockSubConn{}
mcc.mu.Lock()
defer mcc.mu.Unlock()
mcc.subConns[sc] = addrs[0]
return sc, nil
}
func (mcc *mockClientConn) RemoveSubConn(sc balancer.SubConn) {
mcc.mu.Lock()
defer mcc.mu.Unlock()
delete(mcc.subConns, sc)
}
const testCacheTimeout = 100 * time.Millisecond
func checkMockCC(mcc *mockClientConn, scLen int) error {
mcc.mu.Lock()
defer mcc.mu.Unlock()
if len(mcc.subConns) != scLen {
return fmt.Errorf("mcc = %+v, want len(mcc.subConns) = %v", mcc.subConns, scLen)
}
return nil
}
func checkCacheCC(ccc *lbCacheClientConn, sccLen, sctaLen int) error {
ccc.mu.Lock()
defer ccc.mu.Unlock()
if len(ccc.subConnCache) != sccLen {
return fmt.Errorf("ccc = %+v, want len(ccc.subConnCache) = %v", ccc.subConnCache, sccLen)
}
if len(ccc.subConnToAddr) != sctaLen {
return fmt.Errorf("ccc = %+v, want len(ccc.subConnToAddr) = %v", ccc.subConnToAddr, sctaLen)
}
return nil
}
// Test that SubConn won't be immediately removed.
func TestLBCacheClientConnExpire(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
}
ccc := newLBCacheClientConn(mcc)
ccc.timeout = testCacheTimeout
if err := checkCacheCC(ccc, 0, 0); err != nil {
t.Fatal(err)
}
sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Should all become empty after timeout.
var err error
for i := 0; i < 2; i++ {
time.Sleep(testCacheTimeout)
err = checkMockCC(mcc, 0)
if err != nil {
continue
}
err = checkCacheCC(ccc, 0, 0)
if err != nil {
continue
}
}
if err != nil {
t.Fatal(err)
}
}
// Test that NewSubConn with the same address of a SubConn being removed will
// reuse the SubConn and cancel the removing.
func TestLBCacheClientConnReuse(t *testing.T) {
mcc := newMockClientConn()
if err := checkMockCC(mcc, 0); err != nil {
t.Fatal(err)
}
ccc := newLBCacheClientConn(mcc)
ccc.timeout = testCacheTimeout
if err := checkCacheCC(ccc, 0, 0); err != nil {
t.Fatal(err)
}
sc, _ := ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Recreate the old subconn, this should cancel the deleting process.
sc, _ = ccc.NewSubConn([]resolver.Address{{Addr: "address1"}}, balancer.NewSubConnOptions{})
// One subconn in MockCC.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// No subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 0, 1); err != nil {
t.Fatal(err)
}
var err error
// Should not become empty after 2*timeout.
time.Sleep(2 * testCacheTimeout)
err = checkMockCC(mcc, 1)
if err != nil {
t.Fatal(err)
}
err = checkCacheCC(ccc, 0, 1)
if err != nil {
t.Fatal(err)
}
// Call remove again, will delete after timeout.
ccc.RemoveSubConn(sc)
// One subconn in MockCC before timeout.
if err := checkMockCC(mcc, 1); err != nil {
t.Fatal(err)
}
// One subconn being deleted, and one in CacheCC.
if err := checkCacheCC(ccc, 1, 1); err != nil {
t.Fatal(err)
}
// Should all become empty after timeout.
for i := 0; i < 2; i++ {
time.Sleep(testCacheTimeout)
err = checkMockCC(mcc, 0)
if err != nil {
continue
}
err = checkCacheCC(ccc, 0, 0)
if err != nil {
continue
}
}
if err != nil {
t.Fatal(err)
}
}

33
vendor/google.golang.org/grpc/balancer/grpclb/regenerate.sh generated vendored Executable file
View File

@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/lb/v1
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/lb/v1/load_balancer.proto > grpc/lb/v1/load_balancer.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/lb/v1/*.proto
popd
rm -f grpc_lb_v1/*.pb.go
cp "$TMP"/grpc/lb/v1/*.pb.go grpc_lb_v1/

View File

@@ -30,12 +30,12 @@ import (
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
"google.golang.org/grpc/test/leakcheck"
)
type testServer struct {

View File

@@ -115,7 +115,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui
return ccb
}
// watcher balancer functions sequencially, so the balancer can be implemeneted
// watcher balancer functions sequentially, so the balancer can be implemented
// lock-free.
func (ccb *ccBalancerWrapper) watcher() {
for {
@@ -197,7 +197,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
if ccb.subConns == nil {
return nil, fmt.Errorf("grpc: ClientConn balancer wrapper was closed")
}
ac, err := ccb.cc.newAddrConn(addrs)
ac, err := ccb.cc.newAddrConn(addrs, opts)
if err != nil {
return nil, err
}
@@ -257,6 +257,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
}
if !acbw.ac.tryUpdateAddrs(addrs) {
cc := acbw.ac.cc
opts := acbw.ac.scopts
acbw.ac.mu.Lock()
// Set old ac.acbw to nil so the Shutdown state update will be ignored
// by balancer.
@@ -272,7 +273,7 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
return
}
ac, err := cc.newAddrConn(addrs)
ac, err := cc.newAddrConn(addrs, opts)
if err != nil {
grpclog.Warningf("acBalancerWrapper: UpdateAddresses: failed to newAddrConn: %v", err)
return

View File

@@ -25,13 +25,39 @@ import (
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/test/leakcheck"
)
var _ balancer.Builder = &magicalLB{}
var _ balancer.Balancer = &magicalLB{}
// magicalLB is a ringer for grpclb. It is used to avoid circular dependencies on the grpclb package
type magicalLB struct{}
func (b *magicalLB) Name() string {
return "grpclb"
}
func (b *magicalLB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return b
}
func (b *magicalLB) HandleSubConnStateChange(balancer.SubConn, connectivity.State) {}
func (b *magicalLB) HandleResolvedAddrs([]resolver.Address, error) {}
func (b *magicalLB) Close() {}
func init() {
balancer.Register(&magicalLB{})
}
func checkPickFirst(cc *ClientConn, servers []*server) error {
var (
req = "port"
@@ -40,7 +66,7 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
)
connected := false
for i := 0; i < 5000; i++ {
if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); errorDesc(err) == servers[0].port {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == servers[0].port {
if connected {
// connected is set to false if peer is not server[0]. So if
// connected is true here, this is the second time we saw
@@ -58,7 +84,7 @@ func checkPickFirst(cc *ClientConn, servers []*server) error {
}
// The following RPCs should all succeed with the first server.
for i := 0; i < 3; i++ {
err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if errorDesc(err) != servers[0].port {
return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[0].port, err)
}
@@ -80,7 +106,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {
for _, s := range servers {
var up bool
for i := 0; i < 5000; i++ {
if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); errorDesc(err) == s.port {
if err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply); errorDesc(err) == s.port {
up = true
break
}
@@ -94,7 +120,7 @@ func checkRoundRobin(cc *ClientConn, servers []*server) error {
serverCount := len(servers)
for i := 0; i < 3*serverCount; i++ {
err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err = cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if errorDesc(err) != servers[i%serverCount].port {
return fmt.Errorf("Index %d: want peer %v, got peer %v", i, servers[i%serverCount].port, err)
}

View File

@@ -29,9 +29,9 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/leakcheck"
// V1 balancer tests use passthrough resolver instead of dns.
// TODO(bar) remove this when removing v1 balaner entirely.
@@ -39,12 +39,16 @@ import (
_ "google.golang.org/grpc/resolver/passthrough"
)
func pickFirstBalancerV1(r naming.Resolver) Balancer {
return &pickFirst{&roundRobin{r: r}}
}
type testWatcher struct {
// the channel to receives name resolution updates
update chan *naming.Update
// the side channel to get to know how many updates in a batch
side chan int
// the channel to notifiy update injector that the update reading is done
// the channel to notify update injector that the update reading is done
readDone chan int
}
@@ -130,7 +134,7 @@ func TestNameDiscovery(t *testing.T) {
defer cc.Close()
req := "port"
var reply string
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Inject the name resolution change to remove servers[0] and add servers[1].
@@ -146,7 +150,7 @@ func TestNameDiscovery(t *testing.T) {
r.w.inject(updates)
// Loop until the rpcs in flight talks to servers[1].
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@@ -163,7 +167,7 @@ func TestEmptyAddrs(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
@@ -177,7 +181,7 @@ func TestEmptyAddrs(t *testing.T) {
for {
time.Sleep(10 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil {
cancel()
break
}
@@ -206,7 +210,7 @@ func TestRoundRobin(t *testing.T) {
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@@ -219,14 +223,14 @@ func TestRoundRobin(t *testing.T) {
r.w.inject([]*naming.Update{u})
// Loop until both servers[2] are up.
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[2].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Check the incoming RPCs served in a round-robin manner.
for i := 0; i < 10; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[i%numServers].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[i%numServers].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port)
}
}
@@ -242,7 +246,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
@@ -254,7 +258,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies.
for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
@@ -267,7 +271,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@@ -275,7 +279,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@@ -302,7 +306,7 @@ func TestGetOnWaitChannel(t *testing.T) {
for {
var reply string
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
@@ -314,7 +318,7 @@ func TestGetOnWaitChannel(t *testing.T) {
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}()
@@ -350,7 +354,7 @@ func TestOneServerDown(t *testing.T) {
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@@ -374,7 +378,7 @@ func TestOneServerDown(t *testing.T) {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is killed around the same time to make it racy between balancer and gRPC internals.
Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
cc.Invoke(context.Background(), "/foo/bar", &req, &reply, FailFast(false))
wg.Done()
}()
}
@@ -403,7 +407,7 @@ func TestOneAddressRemoval(t *testing.T) {
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
@@ -433,8 +437,8 @@ func TestOneAddressRemoval(t *testing.T) {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want nil", err)
}
wg.Done()
}()
@@ -452,7 +456,7 @@ func checkServerUp(t *testing.T, currentServer *server) {
defer cc.Close()
var reply string
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == port {
break
}
time.Sleep(10 * time.Millisecond)
@@ -469,7 +473,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
@@ -483,7 +487,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) {
for {
time.Sleep(10 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply); err != nil {
cancel()
break
}
@@ -501,7 +505,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
}
defer cc.Close()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
@@ -513,7 +517,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies.
for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
if err := cc.Invoke(ctx, "/foo/bar", &expectedRequest, &reply, FailFast(false)); status.Code(err) == codes.DeadlineExceeded {
cancel()
break
}
@@ -526,7 +530,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@@ -534,7 +538,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
@@ -576,7 +580,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@@ -591,13 +595,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
r.w.inject([]*naming.Update{u})
// Loop until it changes to server[1]
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@@ -611,7 +615,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@@ -624,13 +628,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[2].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[2].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[2].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[2].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port)
}
time.Sleep(10 * time.Millisecond)
@@ -643,13 +647,13 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@@ -689,7 +693,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
req := "port"
var reply string
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@@ -700,13 +704,13 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
servers[0].stop()
// Loop until it changes to server[1]
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@@ -721,7 +725,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
checkServerUp(t, servers[0])
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[1].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[1].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
}
time.Sleep(10 * time.Millisecond)
@@ -734,13 +738,13 @@ func TestPickFirstOrderOneServerDown(t *testing.T) {
}
r.w.inject([]*naming.Update{u})
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && errorDesc(err) == servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err != nil && errorDesc(err) == servers[0].port {
break
}
time.Sleep(1 * time.Second)
}
for i := 0; i < 20; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || errorDesc(err) != servers[0].port {
if err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply); err == nil || errorDesc(err) != servers[0].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
}
time.Sleep(10 * time.Millisecond)
@@ -794,8 +798,8 @@ func TestPickFirstOneAddressRemoval(t *testing.T) {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want nil", err)
}
wg.Done()
}()

View File

@@ -55,7 +55,7 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B
startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState),
csEvltr: &connectivityStateEvaluator{},
csEvltr: &balancer.ConnectivityStateEvaluator{},
state: connectivity.Idle,
}
cc.UpdateBalancerState(connectivity.Idle, bw)
@@ -80,10 +80,6 @@ type balancerWrapper struct {
cc balancer.ClientConn
targetAddr string // Target without the scheme.
// To aggregate the connectivity state.
csEvltr *connectivityStateEvaluator
state connectivity.State
mu sync.Mutex
conns map[resolver.Address]balancer.SubConn
connSt map[balancer.SubConn]*scState
@@ -92,6 +88,10 @@ type balancerWrapper struct {
// - NewSubConn is created, cc wants to notify balancer of state changes;
// - Build hasn't return, cc doesn't have access to balancer.
startCh chan struct{}
// To aggregate the connectivity state.
csEvltr *balancer.ConnectivityStateEvaluator
state connectivity.State
}
// lbWatcher watches the Notify channel of the balancer and manages
@@ -248,7 +248,7 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne
scSt.down(errConnClosing)
}
}
sa := bw.csEvltr.recordTransition(oldS, s)
sa := bw.csEvltr.RecordTransition(oldS, s)
if bw.state != sa {
bw.state = sa
}
@@ -257,7 +257,6 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne
// Remove state for this sc.
delete(bw.connSt, sc)
}
return
}
func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
@@ -270,7 +269,6 @@ func (bw *balancerWrapper) HandleResolvedAddrs([]resolver.Address, error) {
}
// There should be a resolver inside the balancer.
// All updates here, if any, are ignored.
return
}
func (bw *balancerWrapper) Close() {
@@ -282,7 +280,6 @@ func (bw *balancerWrapper) Close() {
close(bw.startCh)
}
bw.balancer.Close()
return
}
// The picker is the balancerWrapper itself.
@@ -329,47 +326,3 @@ func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions)
return sc, done, nil
}
// connectivityStateEvaluator gets updated by addrConns when their
// states transition, based on which it evaluates the state of
// ClientConn.
type connectivityStateEvaluator struct {
mu sync.Mutex
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
// recordTransition records state change happening in every subConn and based on
// that it evaluates what aggregated state should be.
// It can only transition between Ready, Connecting and TransientFailure. Other states,
// Idle and Shutdown are transitioned into by ClientConn; in the beginning of the connection
// before any subConn is created ClientConn is in idle state. In the end when ClientConn
// closes it is in Shutdown state.
// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
cse.mu.Lock()
defer cse.mu.Unlock()
// Update counters.
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
cse.numReady += updateVal
case connectivity.Connecting:
cse.numConnecting += updateVal
case connectivity.TransientFailure:
cse.numTransientFailure += updateVal
}
}
// Evaluate.
if cse.numReady > 0 {
return connectivity.Ready
}
if cse.numConnecting > 0 {
return connectivity.Connecting
}
return connectivity.TransientFailure
}

View File

@@ -66,6 +66,8 @@ import (
"google.golang.org/grpc/benchmark/latency"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/test/bufconn"
)
const (
@@ -100,6 +102,7 @@ var (
memProfile, cpuProfile string
memProfileRate int
enableCompressor []bool
enableChannelz []bool
networkMode string
benchmarkResultFile string
networks = map[string]latency.Network{
@@ -111,14 +114,14 @@ var (
)
func unaryBenchmark(startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
caller, close := makeFuncUnary(benchFeatures)
defer close()
caller, cleanup := makeFuncUnary(benchFeatures)
defer cleanup()
runBenchmark(caller, startTimer, stopTimer, benchFeatures, benchtime, s)
}
func streamBenchmark(startTimer func(), stopTimer func(int32), benchFeatures stats.Features, benchtime time.Duration, s *stats.Stats) {
caller, close := makeFuncStream(benchFeatures)
defer close()
caller, cleanup := makeFuncStream(benchFeatures)
defer cleanup()
runBenchmark(caller, startTimer, stopTimer, benchFeatures, benchtime, s)
}
@@ -137,13 +140,31 @@ func makeFuncUnary(benchFeatures stats.Features) (func(int), func()) {
)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout)
}))
opts = append(opts, grpc.WithInsecure())
target, stopper := bm.StartServer(bm.ServerInfo{Addr: "localhost:0", Type: "protobuf", Network: nw}, sopts...)
conn := bm.NewClientConn(target, opts...)
var lis net.Listener
if *useBufconn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(
func(string, string, time.Duration) (net.Conn, error) {
return bcLis.Dial()
})("", "", 0)
}))
} else {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout)
}))
}
lis = nw.Listener(lis)
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
tc := testpb.NewBenchmarkServiceClient(conn)
return func(int) {
unaryCaller(tc, benchFeatures.ReqSizeBytes, benchFeatures.RespSizeBytes)
@@ -154,6 +175,7 @@ func makeFuncUnary(benchFeatures stats.Features) (func(int), func()) {
}
func makeFuncStream(benchFeatures stats.Features) (func(int), func()) {
// TODO: Refactor to remove duplication with makeFuncUnary.
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
opts := []grpc.DialOption{}
sopts := []grpc.ServerOption{}
@@ -168,13 +190,31 @@ func makeFuncStream(benchFeatures stats.Features) (func(int), func()) {
)
}
sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout)
}))
opts = append(opts, grpc.WithInsecure())
target, stopper := bm.StartServer(bm.ServerInfo{Addr: "localhost:0", Type: "protobuf", Network: nw}, sopts...)
conn := bm.NewClientConn(target, opts...)
var lis net.Listener
if *useBufconn {
bcLis := bufconn.Listen(256 * 1024)
lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(
func(string, string, time.Duration) (net.Conn, error) {
return bcLis.Dial()
})("", "", 0)
}))
} else {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout)
}))
}
lis = nw.Listener(lis)
stopper := bm.StartServer(bm.ServerInfo{Type: "protobuf", Listener: lis}, sopts...)
conn := bm.NewClientConn("" /* target not used */, opts...)
tc := testpb.NewBenchmarkServiceClient(conn)
streams := make([]testpb.BenchmarkService_StreamingCallClient, benchFeatures.MaxConcurrentCalls)
for i := 0; i < benchFeatures.MaxConcurrentCalls; i++ {
@@ -240,18 +280,21 @@ func runBenchmark(caller func(int), startTimer func(), stopTimer func(int32), be
stopTimer(count)
}
var useBufconn = flag.Bool("bufconn", false, "Use in-memory connection instead of system network I/O")
// Initiate main function to get settings of features.
func init() {
var (
workloads, traceMode, compressorMode, readLatency string
readKbps, readMtu, readMaxConcurrentCalls intSliceType
readReqSizeBytes, readRespSizeBytes intSliceType
workloads, traceMode, compressorMode, readLatency, channelzOn string
readKbps, readMtu, readMaxConcurrentCalls intSliceType
readReqSizeBytes, readRespSizeBytes intSliceType
)
flag.StringVar(&workloads, "workloads", workloadsAll,
fmt.Sprintf("Workloads to execute - One of: %v", strings.Join(allWorkloads, ", ")))
flag.StringVar(&traceMode, "trace", modeOff,
fmt.Sprintf("Trace mode - One of: %v", strings.Join(allTraceModes, ", ")))
flag.StringVar(&readLatency, "latency", "", "Simulated one-way network latency - may be a comma-separated list")
flag.StringVar(&channelzOn, "channelz", modeOff, "whether channelz should be turned on")
flag.DurationVar(&benchtime, "benchtime", time.Second, "Configures the amount of time to run each benchmark")
flag.Var(&readKbps, "kbps", "Simulated network throughput (in kbps) - may be a comma-separated list")
flag.Var(&readMtu, "mtu", "Simulated network MTU (Maximum Transmission Unit) - may be a comma-separated list")
@@ -287,6 +330,7 @@ func init() {
}
enableCompressor = setMode(compressorMode)
enableTrace = setMode(traceMode)
enableChannelz = setMode(channelzOn)
// Time input formats as (time + unit).
readTimeFromInput(&ltc, readLatency)
readIntFromIntSlice(&kbps, readKbps)
@@ -360,10 +404,10 @@ func readTimeFromInput(values *[]time.Duration, replace string) {
func main() {
before()
featuresPos := make([]int, 8)
featuresPos := make([]int, 9)
// 0:enableTracing 1:ltc 2:kbps 3:mtu 4:maxC 5:reqSize 6:respSize
featuresNum := []int{len(enableTrace), len(ltc), len(kbps), len(mtu),
len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes), len(enableCompressor)}
len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes), len(enableCompressor), len(enableChannelz)}
initalPos := make([]int, len(featuresPos))
s := stats.NewStats(10)
s.SortLatency()
@@ -380,7 +424,7 @@ func main() {
}
var stopTimer = func(count int32) {
runtime.ReadMemStats(&memStats)
results = testing.BenchmarkResult{N: int(count), T: time.Now().Sub(startTime),
results = testing.BenchmarkResult{N: int(count), T: time.Since(startTime),
Bytes: 0, MemAllocs: memStats.Mallocs - startAllocs, MemBytes: memStats.TotalAlloc - startBytes}
}
sharedPos := make([]bool, len(featuresPos))
@@ -404,9 +448,13 @@ func main() {
ReqSizeBytes: reqSizeBytes[featuresPos[5]],
RespSizeBytes: respSizeBytes[featuresPos[6]],
EnableCompressor: enableCompressor[featuresPos[7]],
EnableChannelz: enableChannelz[featuresPos[8]],
}
grpc.EnableTracing = enableTrace[featuresPos[0]]
if enableChannelz[featuresPos[8]] {
channelz.TurnOn()
}
if runMode[0] {
unaryBenchmark(startTimer, stopTimer, benchFeature, benchtime, s)
s.SetBenchmarkResult("Unary", benchFeature, results.N,

View File

@@ -65,7 +65,6 @@ func setPayload(p *testpb.Payload, t testpb.PayloadType, size int) {
}
p.Type = t
p.Body = body
return
}
func newPayload(t testpb.PayloadType, size int) *testpb.Payload {
@@ -136,9 +135,6 @@ func (s *byteBufServer) StreamingCall(stream testpb.BenchmarkService_StreamingCa
// ServerInfo contains the information to create a gRPC benchmark server.
type ServerInfo struct {
// Addr is the address of the server.
Addr string
// Type is the type of the server.
// It should be "protobuf" or "bytebuf".
Type string
@@ -148,21 +144,13 @@ type ServerInfo struct {
// For "bytebuf", it should be an int representing response size.
Metadata interface{}
// Network can simulate latency
Network *latency.Network
// Listener is the network listener for the server to use
Listener net.Listener
}
// StartServer starts a gRPC server serving a benchmark service according to info.
// It returns its listen address and a function to stop the server.
func StartServer(info ServerInfo, opts ...grpc.ServerOption) (string, func()) {
lis, err := net.Listen("tcp", info.Addr)
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
nw := info.Network
if nw != nil {
lis = nw.Listener(lis)
}
// It returns a function to stop the server.
func StartServer(info ServerInfo, opts ...grpc.ServerOption) func() {
opts = append(opts, grpc.WriteBufferSize(128*1024))
opts = append(opts, grpc.ReadBufferSize(128*1024))
s := grpc.NewServer(opts...)
@@ -178,8 +166,8 @@ func StartServer(info ServerInfo, opts ...grpc.ServerOption) (string, func()) {
default:
grpclog.Fatalf("failed to StartServer, unknown Type: %v", info.Type)
}
go s.Serve(lis)
return lis.Addr().String(), func() {
go s.Serve(info.Listener)
return func() {
s.Stop()
}
}
@@ -238,9 +226,14 @@ func DoByteBufStreamingRoundTrip(stream testpb.BenchmarkService_StreamingCallCli
// NewClientConn creates a gRPC client connection to addr.
func NewClientConn(addr string, opts ...grpc.DialOption) *grpc.ClientConn {
return NewClientConnWithContext(context.Background(), addr, opts...)
}
// NewClientConnWithContext creates a gRPC client connection to addr using ctx.
func NewClientConnWithContext(ctx context.Context, addr string, opts ...grpc.DialOption) *grpc.ClientConn {
opts = append(opts, grpc.WithWriteBufferSize(128*1024))
opts = append(opts, grpc.WithReadBufferSize(128*1024))
conn, err := grpc.Dial(addr, opts...)
conn, err := grpc.DialContext(ctx, addr, opts...)
if err != nil {
grpclog.Fatalf("NewClientConn(%q) failed to create a ClientConn %v", addr, err)
}
@@ -250,7 +243,13 @@ func NewClientConn(addr string, opts ...grpc.DialOption) *grpc.ClientConn {
func runUnary(b *testing.B, benchFeatures stats.Features) {
s := stats.AddStats(b, 38)
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
target, stopper := StartServer(ServerInfo{Addr: "localhost:0", Type: "protobuf", Network: nw}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
target := lis.Addr().String()
lis = nw.Listener(lis)
stopper := StartServer(ServerInfo{Type: "protobuf", Listener: lis}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
defer stopper()
conn := NewClientConn(
target, grpc.WithInsecure(),
@@ -298,7 +297,13 @@ func runUnary(b *testing.B, benchFeatures stats.Features) {
func runStream(b *testing.B, benchFeatures stats.Features) {
s := stats.AddStats(b, 38)
nw := &latency.Network{Kbps: benchFeatures.Kbps, Latency: benchFeatures.Latency, MTU: benchFeatures.Mtu}
target, stopper := StartServer(ServerInfo{Addr: "localhost:0", Type: "protobuf", Network: nw}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
target := lis.Addr().String()
lis = nw.Listener(lis)
stopper := StartServer(ServerInfo{Type: "protobuf", Listener: lis}, grpc.MaxConcurrentStreams(uint32(benchFeatures.MaxConcurrentCalls+1)))
defer stopper()
conn := NewClientConn(
target, grpc.WithInsecure(),

View File

@@ -30,81 +30,81 @@ import (
func BenchmarkClientStreamc1(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientStreamc8(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientStreamc64(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientStreamc512(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false, false})
}
func BenchmarkClientUnaryc1(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientUnaryc8(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientUnaryc64(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientUnaryc512(b *testing.B) {
grpc.EnableTracing = true
runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false})
runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec1(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec8(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec64(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientStreamNoTracec512(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec1(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec8(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec64(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false, false})
}
func BenchmarkClientUnaryNoTracec512(b *testing.B) {
grpc.EnableTracing = false
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false})
runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false})
}
func TestMain(m *testing.M) {

View File

@@ -20,10 +20,10 @@ package main
import (
"flag"
"math"
"net"
"net/http"
_ "net/http/pprof"
"fmt"
"os"
"runtime"
"runtime/pprof"
"sync"
"time"
@@ -33,148 +33,155 @@ import (
testpb "google.golang.org/grpc/benchmark/grpc_testing"
"google.golang.org/grpc/benchmark/stats"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
)
var (
server = flag.String("server", "", "The server address")
maxConcurrentRPCs = flag.Int("max_concurrent_rpcs", 1, "The max number of concurrent RPCs")
duration = flag.Int("duration", math.MaxInt32, "The duration in seconds to run the benchmark client")
trace = flag.Bool("trace", true, "Whether tracing is on")
rpcType = flag.Int("rpc_type", 0,
port = flag.String("port", "50051", "Localhost port to connect to.")
numRPC = flag.Int("r", 1, "The number of concurrent RPCs on each connection.")
numConn = flag.Int("c", 1, "The number of parallel connections.")
warmupDur = flag.Int("w", 10, "Warm-up duration in seconds")
duration = flag.Int("d", 60, "Benchmark duration in seconds")
rqSize = flag.Int("req", 1, "Request message size in bytes.")
rspSize = flag.Int("resp", 1, "Response message size in bytes.")
rpcType = flag.String("rpc_type", "unary",
`Configure different client rpc type. Valid options are:
0 : unary call;
1 : streaming call.`)
unary;
streaming.`)
testName = flag.String("test_name", "", "Name of the test used for creating profiles.")
wg sync.WaitGroup
hopts = stats.HistogramOptions{
NumBuckets: 2495,
GrowthFactor: .01,
}
mu sync.Mutex
hists []*stats.Histogram
)
func unaryCaller(client testpb.BenchmarkServiceClient) {
benchmark.DoUnaryCall(client, 1, 1)
}
func streamCaller(stream testpb.BenchmarkService_StreamingCallClient) {
benchmark.DoStreamingRoundTrip(stream, 1, 1)
}
func buildConnection() (s *stats.Stats, conn *grpc.ClientConn, tc testpb.BenchmarkServiceClient) {
s = stats.NewStats(256)
conn = benchmark.NewClientConn(*server)
tc = testpb.NewBenchmarkServiceClient(conn)
return s, conn, tc
}
func closeLoopUnary() {
s, conn, tc := buildConnection()
for i := 0; i < 100; i++ {
unaryCaller(tc)
}
ch := make(chan int, *maxConcurrentRPCs*4)
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(*maxConcurrentRPCs)
for i := 0; i < *maxConcurrentRPCs; i++ {
go func() {
for range ch {
start := time.Now()
unaryCaller(tc)
elapse := time.Since(start)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}()
}
// Stop the client when time is up.
done := make(chan struct{})
go func() {
<-time.After(time.Duration(*duration) * time.Second)
close(done)
}()
ok := true
for ok {
select {
case ch <- 0:
case <-done:
ok = false
}
}
close(ch)
wg.Wait()
conn.Close()
grpclog.Println(s.String())
}
func closeLoopStream() {
s, conn, tc := buildConnection()
ch := make(chan int, *maxConcurrentRPCs*4)
var (
mu sync.Mutex
wg sync.WaitGroup
)
wg.Add(*maxConcurrentRPCs)
// Distribute RPCs over maxConcurrentCalls workers.
for i := 0; i < *maxConcurrentRPCs; i++ {
go func() {
stream, err := tc.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err)
}
// Do some warm up.
for i := 0; i < 100; i++ {
streamCaller(stream)
}
for range ch {
start := time.Now()
streamCaller(stream)
elapse := time.Since(start)
mu.Lock()
s.Add(elapse)
mu.Unlock()
}
wg.Done()
}()
}
// Stop the client when time is up.
done := make(chan struct{})
go func() {
<-time.After(time.Duration(*duration) * time.Second)
close(done)
}()
ok := true
for ok {
select {
case ch <- 0:
case <-done:
ok = false
}
}
close(ch)
wg.Wait()
conn.Close()
grpclog.Println(s.String())
}
func main() {
flag.Parse()
grpc.EnableTracing = *trace
go func() {
lis, err := net.Listen("tcp", ":0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
grpclog.Println("Client profiling address: ", lis.Addr().String())
if err := http.Serve(lis, nil); err != nil {
grpclog.Fatalf("Failed to serve: %v", err)
}
}()
switch *rpcType {
case 0:
closeLoopUnary()
case 1:
closeLoopStream()
if *testName == "" {
grpclog.Fatalf("test_name not set")
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE,
ResponseSize: int32(*rspSize),
Payload: &testpb.Payload{
Type: testpb.PayloadType_COMPRESSABLE,
Body: make([]byte, *rqSize),
},
}
connectCtx, connectCancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer connectCancel()
ccs := buildConnections(connectCtx)
warmDeadline := time.Now().Add(time.Duration(*warmupDur) * time.Second)
endDeadline := warmDeadline.Add(time.Duration(*duration) * time.Second)
cf, err := os.Create("/tmp/" + *testName + ".cpu")
if err != nil {
grpclog.Fatalf("Error creating file: %v", err)
}
defer cf.Close()
pprof.StartCPUProfile(cf)
cpuBeg := syscall.GetCPUTime()
for _, cc := range ccs {
runWithConn(cc, req, warmDeadline, endDeadline)
}
wg.Wait()
cpu := time.Duration(syscall.GetCPUTime() - cpuBeg)
pprof.StopCPUProfile()
mf, err := os.Create("/tmp/" + *testName + ".mem")
if err != nil {
grpclog.Fatalf("Error creating file: %v", err)
}
defer mf.Close()
runtime.GC() // materialize all statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
grpclog.Fatalf("Error writing memory profile: %v", err)
}
hist := stats.NewHistogram(hopts)
for _, h := range hists {
hist.Merge(h)
}
parseHist(hist)
fmt.Println("Client CPU utilization:", cpu)
fmt.Println("Client CPU profile:", cf.Name())
fmt.Println("Client Mem Profile:", mf.Name())
}
func buildConnections(ctx context.Context) []*grpc.ClientConn {
ccs := make([]*grpc.ClientConn, *numConn)
for i := range ccs {
ccs[i] = benchmark.NewClientConnWithContext(ctx, "localhost:"+*port, grpc.WithInsecure(), grpc.WithBlock())
}
return ccs
}
func runWithConn(cc *grpc.ClientConn, req *testpb.SimpleRequest, warmDeadline, endDeadline time.Time) {
for i := 0; i < *numRPC; i++ {
wg.Add(1)
go func() {
defer wg.Done()
caller := makeCaller(cc, req)
hist := stats.NewHistogram(hopts)
for {
start := time.Now()
if start.After(endDeadline) {
mu.Lock()
hists = append(hists, hist)
mu.Unlock()
return
}
caller()
elapsed := time.Since(start)
if start.After(warmDeadline) {
hist.Add(elapsed.Nanoseconds())
}
}
}()
}
}
func makeCaller(cc *grpc.ClientConn, req *testpb.SimpleRequest) func() {
client := testpb.NewBenchmarkServiceClient(cc)
if *rpcType == "unary" {
return func() {
if _, err := client.UnaryCall(context.Background(), req); err != nil {
grpclog.Fatalf("RPC failed: %v", err)
}
}
}
stream, err := client.StreamingCall(context.Background())
if err != nil {
grpclog.Fatalf("RPC failed: %v", err)
}
return func() {
if err := stream.Send(req); err != nil {
grpclog.Fatalf("Streaming RPC failed to send: %v", err)
}
if _, err := stream.Recv(); err != nil {
grpclog.Fatalf("Streaming RPC failed to read: %v", err)
}
}
}
func parseHist(hist *stats.Histogram) {
fmt.Println("qps:", float64(hist.Count)/float64(*duration))
fmt.Printf("Latency: (50/90/99 %%ile): %v/%v/%v\n",
time.Duration(median(.5, hist)),
time.Duration(median(.9, hist)),
time.Duration(median(.99, hist)))
}
func median(percentile float64, h *stats.Histogram) int64 {
need := int64(float64(h.Count) * percentile)
have := int64(0)
for _, bucket := range h.Buckets {
count := bucket.Count
if have+count >= need {
percent := float64(need-have) / float64(count)
return int64((1.0-percent)*bucket.LowBound + percent*bucket.LowBound*(1.0+hopts.GrowthFactor))
}
have += bucket.Count
}
panic("should have found a bound")
}

File diff suppressed because it is too large Load Diff

View File

@@ -12,6 +12,12 @@ var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// The type of payload that should be returned.
type PayloadType int32
@@ -38,7 +44,9 @@ var PayloadType_value = map[string]int32{
func (x PayloadType) String() string {
return proto.EnumName(PayloadType_name, int32(x))
}
func (PayloadType) EnumDescriptor() ([]byte, []int) { return fileDescriptor1, []int{0} }
func (PayloadType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{0}
}
// Compression algorithms
type CompressionType int32
@@ -64,20 +72,44 @@ var CompressionType_value = map[string]int32{
func (x CompressionType) String() string {
return proto.EnumName(CompressionType_name, int32(x))
}
func (CompressionType) EnumDescriptor() ([]byte, []int) { return fileDescriptor1, []int{1} }
func (CompressionType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{1}
}
// A block of data, to simply increase gRPC message size.
type Payload struct {
// The type of data in body.
Type PayloadType `protobuf:"varint,1,opt,name=type,enum=grpc.testing.PayloadType" json:"type,omitempty"`
Type PayloadType `protobuf:"varint,1,opt,name=type,proto3,enum=grpc.testing.PayloadType" json:"type,omitempty"`
// Primary contents of payload.
Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"`
Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Payload) Reset() { *m = Payload{} }
func (m *Payload) String() string { return proto.CompactTextString(m) }
func (*Payload) ProtoMessage() {}
func (*Payload) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{0} }
func (m *Payload) Reset() { *m = Payload{} }
func (m *Payload) String() string { return proto.CompactTextString(m) }
func (*Payload) ProtoMessage() {}
func (*Payload) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{0}
}
func (m *Payload) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Payload.Unmarshal(m, b)
}
func (m *Payload) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Payload.Marshal(b, m, deterministic)
}
func (dst *Payload) XXX_Merge(src proto.Message) {
xxx_messageInfo_Payload.Merge(dst, src)
}
func (m *Payload) XXX_Size() int {
return xxx_messageInfo_Payload.Size(m)
}
func (m *Payload) XXX_DiscardUnknown() {
xxx_messageInfo_Payload.DiscardUnknown(m)
}
var xxx_messageInfo_Payload proto.InternalMessageInfo
func (m *Payload) GetType() PayloadType {
if m != nil {
@@ -96,14 +128,36 @@ func (m *Payload) GetBody() []byte {
// A protobuf representation for grpc status. This is used by test
// clients to specify a status that the server should attempt to return.
type EchoStatus struct {
Code int32 `protobuf:"varint,1,opt,name=code" json:"code,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"`
Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *EchoStatus) Reset() { *m = EchoStatus{} }
func (m *EchoStatus) String() string { return proto.CompactTextString(m) }
func (*EchoStatus) ProtoMessage() {}
func (*EchoStatus) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{1} }
func (m *EchoStatus) Reset() { *m = EchoStatus{} }
func (m *EchoStatus) String() string { return proto.CompactTextString(m) }
func (*EchoStatus) ProtoMessage() {}
func (*EchoStatus) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{1}
}
func (m *EchoStatus) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_EchoStatus.Unmarshal(m, b)
}
func (m *EchoStatus) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_EchoStatus.Marshal(b, m, deterministic)
}
func (dst *EchoStatus) XXX_Merge(src proto.Message) {
xxx_messageInfo_EchoStatus.Merge(dst, src)
}
func (m *EchoStatus) XXX_Size() int {
return xxx_messageInfo_EchoStatus.Size(m)
}
func (m *EchoStatus) XXX_DiscardUnknown() {
xxx_messageInfo_EchoStatus.DiscardUnknown(m)
}
var xxx_messageInfo_EchoStatus proto.InternalMessageInfo
func (m *EchoStatus) GetCode() int32 {
if m != nil {
@@ -123,26 +177,48 @@ func (m *EchoStatus) GetMessage() string {
type SimpleRequest struct {
// Desired payload type in the response from the server.
// If response_type is RANDOM, server randomly chooses one from other formats.
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,proto3,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
// Desired payload size in the response from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
ResponseSize int32 `protobuf:"varint,2,opt,name=response_size,json=responseSize" json:"response_size,omitempty"`
ResponseSize int32 `protobuf:"varint,2,opt,name=response_size,json=responseSize,proto3" json:"response_size,omitempty"`
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"`
Payload *Payload `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"`
// Whether SimpleResponse should include username.
FillUsername bool `protobuf:"varint,4,opt,name=fill_username,json=fillUsername" json:"fill_username,omitempty"`
FillUsername bool `protobuf:"varint,4,opt,name=fill_username,json=fillUsername,proto3" json:"fill_username,omitempty"`
// Whether SimpleResponse should include OAuth scope.
FillOauthScope bool `protobuf:"varint,5,opt,name=fill_oauth_scope,json=fillOauthScope" json:"fill_oauth_scope,omitempty"`
FillOauthScope bool `protobuf:"varint,5,opt,name=fill_oauth_scope,json=fillOauthScope,proto3" json:"fill_oauth_scope,omitempty"`
// Compression algorithm to be used by the server for the response (stream)
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,proto3,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
// Whether server should return a given status
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus" json:"response_status,omitempty"`
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus,proto3" json:"response_status,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleRequest) Reset() { *m = SimpleRequest{} }
func (m *SimpleRequest) String() string { return proto.CompactTextString(m) }
func (*SimpleRequest) ProtoMessage() {}
func (*SimpleRequest) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{2} }
func (m *SimpleRequest) Reset() { *m = SimpleRequest{} }
func (m *SimpleRequest) String() string { return proto.CompactTextString(m) }
func (*SimpleRequest) ProtoMessage() {}
func (*SimpleRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{2}
}
func (m *SimpleRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleRequest.Unmarshal(m, b)
}
func (m *SimpleRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleRequest.Marshal(b, m, deterministic)
}
func (dst *SimpleRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleRequest.Merge(dst, src)
}
func (m *SimpleRequest) XXX_Size() int {
return xxx_messageInfo_SimpleRequest.Size(m)
}
func (m *SimpleRequest) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleRequest.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleRequest proto.InternalMessageInfo
func (m *SimpleRequest) GetResponseType() PayloadType {
if m != nil {
@@ -196,18 +272,40 @@ func (m *SimpleRequest) GetResponseStatus() *EchoStatus {
// Unary response, as configured by the request.
type SimpleResponse struct {
// Payload to increase message size.
Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"`
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
// The user the request came from, for verifying authentication was
// successful when the client expected it.
Username string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"`
// OAuth scope.
OauthScope string `protobuf:"bytes,3,opt,name=oauth_scope,json=oauthScope" json:"oauth_scope,omitempty"`
OauthScope string `protobuf:"bytes,3,opt,name=oauth_scope,json=oauthScope,proto3" json:"oauth_scope,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleResponse) Reset() { *m = SimpleResponse{} }
func (m *SimpleResponse) String() string { return proto.CompactTextString(m) }
func (*SimpleResponse) ProtoMessage() {}
func (*SimpleResponse) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{3} }
func (m *SimpleResponse) Reset() { *m = SimpleResponse{} }
func (m *SimpleResponse) String() string { return proto.CompactTextString(m) }
func (*SimpleResponse) ProtoMessage() {}
func (*SimpleResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{3}
}
func (m *SimpleResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleResponse.Unmarshal(m, b)
}
func (m *SimpleResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleResponse.Marshal(b, m, deterministic)
}
func (dst *SimpleResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleResponse.Merge(dst, src)
}
func (m *SimpleResponse) XXX_Size() int {
return xxx_messageInfo_SimpleResponse.Size(m)
}
func (m *SimpleResponse) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleResponse.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleResponse proto.InternalMessageInfo
func (m *SimpleResponse) GetPayload() *Payload {
if m != nil {
@@ -233,13 +331,35 @@ func (m *SimpleResponse) GetOauthScope() string {
// Client-streaming request.
type StreamingInputCallRequest struct {
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"`
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingInputCallRequest) Reset() { *m = StreamingInputCallRequest{} }
func (m *StreamingInputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallRequest) ProtoMessage() {}
func (*StreamingInputCallRequest) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{4} }
func (m *StreamingInputCallRequest) Reset() { *m = StreamingInputCallRequest{} }
func (m *StreamingInputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallRequest) ProtoMessage() {}
func (*StreamingInputCallRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{4}
}
func (m *StreamingInputCallRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingInputCallRequest.Unmarshal(m, b)
}
func (m *StreamingInputCallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingInputCallRequest.Marshal(b, m, deterministic)
}
func (dst *StreamingInputCallRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingInputCallRequest.Merge(dst, src)
}
func (m *StreamingInputCallRequest) XXX_Size() int {
return xxx_messageInfo_StreamingInputCallRequest.Size(m)
}
func (m *StreamingInputCallRequest) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingInputCallRequest.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingInputCallRequest proto.InternalMessageInfo
func (m *StreamingInputCallRequest) GetPayload() *Payload {
if m != nil {
@@ -251,13 +371,35 @@ func (m *StreamingInputCallRequest) GetPayload() *Payload {
// Client-streaming response.
type StreamingInputCallResponse struct {
// Aggregated size of payloads received from the client.
AggregatedPayloadSize int32 `protobuf:"varint,1,opt,name=aggregated_payload_size,json=aggregatedPayloadSize" json:"aggregated_payload_size,omitempty"`
AggregatedPayloadSize int32 `protobuf:"varint,1,opt,name=aggregated_payload_size,json=aggregatedPayloadSize,proto3" json:"aggregated_payload_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingInputCallResponse) Reset() { *m = StreamingInputCallResponse{} }
func (m *StreamingInputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallResponse) ProtoMessage() {}
func (*StreamingInputCallResponse) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{5} }
func (m *StreamingInputCallResponse) Reset() { *m = StreamingInputCallResponse{} }
func (m *StreamingInputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingInputCallResponse) ProtoMessage() {}
func (*StreamingInputCallResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{5}
}
func (m *StreamingInputCallResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingInputCallResponse.Unmarshal(m, b)
}
func (m *StreamingInputCallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingInputCallResponse.Marshal(b, m, deterministic)
}
func (dst *StreamingInputCallResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingInputCallResponse.Merge(dst, src)
}
func (m *StreamingInputCallResponse) XXX_Size() int {
return xxx_messageInfo_StreamingInputCallResponse.Size(m)
}
func (m *StreamingInputCallResponse) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingInputCallResponse.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingInputCallResponse proto.InternalMessageInfo
func (m *StreamingInputCallResponse) GetAggregatedPayloadSize() int32 {
if m != nil {
@@ -270,16 +412,38 @@ func (m *StreamingInputCallResponse) GetAggregatedPayloadSize() int32 {
type ResponseParameters struct {
// Desired payload sizes in responses from the server.
// If response_type is COMPRESSABLE, this denotes the size before compression.
Size int32 `protobuf:"varint,1,opt,name=size" json:"size,omitempty"`
Size int32 `protobuf:"varint,1,opt,name=size,proto3" json:"size,omitempty"`
// Desired interval between consecutive responses in the response stream in
// microseconds.
IntervalUs int32 `protobuf:"varint,2,opt,name=interval_us,json=intervalUs" json:"interval_us,omitempty"`
IntervalUs int32 `protobuf:"varint,2,opt,name=interval_us,json=intervalUs,proto3" json:"interval_us,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ResponseParameters) Reset() { *m = ResponseParameters{} }
func (m *ResponseParameters) String() string { return proto.CompactTextString(m) }
func (*ResponseParameters) ProtoMessage() {}
func (*ResponseParameters) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{6} }
func (m *ResponseParameters) Reset() { *m = ResponseParameters{} }
func (m *ResponseParameters) String() string { return proto.CompactTextString(m) }
func (*ResponseParameters) ProtoMessage() {}
func (*ResponseParameters) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{6}
}
func (m *ResponseParameters) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ResponseParameters.Unmarshal(m, b)
}
func (m *ResponseParameters) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ResponseParameters.Marshal(b, m, deterministic)
}
func (dst *ResponseParameters) XXX_Merge(src proto.Message) {
xxx_messageInfo_ResponseParameters.Merge(dst, src)
}
func (m *ResponseParameters) XXX_Size() int {
return xxx_messageInfo_ResponseParameters.Size(m)
}
func (m *ResponseParameters) XXX_DiscardUnknown() {
xxx_messageInfo_ResponseParameters.DiscardUnknown(m)
}
var xxx_messageInfo_ResponseParameters proto.InternalMessageInfo
func (m *ResponseParameters) GetSize() int32 {
if m != nil {
@@ -301,21 +465,43 @@ type StreamingOutputCallRequest struct {
// If response_type is RANDOM, the payload from each response in the stream
// might be of different types. This is to simulate a mixed type of payload
// stream.
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,json=responseType,proto3,enum=grpc.testing.PayloadType" json:"response_type,omitempty"`
// Configuration for each expected response message.
ResponseParameters []*ResponseParameters `protobuf:"bytes,2,rep,name=response_parameters,json=responseParameters" json:"response_parameters,omitempty"`
ResponseParameters []*ResponseParameters `protobuf:"bytes,2,rep,name=response_parameters,json=responseParameters,proto3" json:"response_parameters,omitempty"`
// Optional input payload sent along with the request.
Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"`
Payload *Payload `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"`
// Compression algorithm to be used by the server for the response (stream)
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
ResponseCompression CompressionType `protobuf:"varint,6,opt,name=response_compression,json=responseCompression,proto3,enum=grpc.testing.CompressionType" json:"response_compression,omitempty"`
// Whether server should return a given status
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus" json:"response_status,omitempty"`
ResponseStatus *EchoStatus `protobuf:"bytes,7,opt,name=response_status,json=responseStatus,proto3" json:"response_status,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingOutputCallRequest) Reset() { *m = StreamingOutputCallRequest{} }
func (m *StreamingOutputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallRequest) ProtoMessage() {}
func (*StreamingOutputCallRequest) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{7} }
func (m *StreamingOutputCallRequest) Reset() { *m = StreamingOutputCallRequest{} }
func (m *StreamingOutputCallRequest) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallRequest) ProtoMessage() {}
func (*StreamingOutputCallRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{7}
}
func (m *StreamingOutputCallRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingOutputCallRequest.Unmarshal(m, b)
}
func (m *StreamingOutputCallRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingOutputCallRequest.Marshal(b, m, deterministic)
}
func (dst *StreamingOutputCallRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingOutputCallRequest.Merge(dst, src)
}
func (m *StreamingOutputCallRequest) XXX_Size() int {
return xxx_messageInfo_StreamingOutputCallRequest.Size(m)
}
func (m *StreamingOutputCallRequest) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingOutputCallRequest.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingOutputCallRequest proto.InternalMessageInfo
func (m *StreamingOutputCallRequest) GetResponseType() PayloadType {
if m != nil {
@@ -355,13 +541,35 @@ func (m *StreamingOutputCallRequest) GetResponseStatus() *EchoStatus {
// Server-streaming response, as configured by the request and parameters.
type StreamingOutputCallResponse struct {
// Payload to increase response size.
Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"`
Payload *Payload `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *StreamingOutputCallResponse) Reset() { *m = StreamingOutputCallResponse{} }
func (m *StreamingOutputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallResponse) ProtoMessage() {}
func (*StreamingOutputCallResponse) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{8} }
func (m *StreamingOutputCallResponse) Reset() { *m = StreamingOutputCallResponse{} }
func (m *StreamingOutputCallResponse) String() string { return proto.CompactTextString(m) }
func (*StreamingOutputCallResponse) ProtoMessage() {}
func (*StreamingOutputCallResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{8}
}
func (m *StreamingOutputCallResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_StreamingOutputCallResponse.Unmarshal(m, b)
}
func (m *StreamingOutputCallResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_StreamingOutputCallResponse.Marshal(b, m, deterministic)
}
func (dst *StreamingOutputCallResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_StreamingOutputCallResponse.Merge(dst, src)
}
func (m *StreamingOutputCallResponse) XXX_Size() int {
return xxx_messageInfo_StreamingOutputCallResponse.Size(m)
}
func (m *StreamingOutputCallResponse) XXX_DiscardUnknown() {
xxx_messageInfo_StreamingOutputCallResponse.DiscardUnknown(m)
}
var xxx_messageInfo_StreamingOutputCallResponse proto.InternalMessageInfo
func (m *StreamingOutputCallResponse) GetPayload() *Payload {
if m != nil {
@@ -373,13 +581,35 @@ func (m *StreamingOutputCallResponse) GetPayload() *Payload {
// For reconnect interop test only.
// Client tells server what reconnection parameters it used.
type ReconnectParams struct {
MaxReconnectBackoffMs int32 `protobuf:"varint,1,opt,name=max_reconnect_backoff_ms,json=maxReconnectBackoffMs" json:"max_reconnect_backoff_ms,omitempty"`
MaxReconnectBackoffMs int32 `protobuf:"varint,1,opt,name=max_reconnect_backoff_ms,json=maxReconnectBackoffMs,proto3" json:"max_reconnect_backoff_ms,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ReconnectParams) Reset() { *m = ReconnectParams{} }
func (m *ReconnectParams) String() string { return proto.CompactTextString(m) }
func (*ReconnectParams) ProtoMessage() {}
func (*ReconnectParams) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{9} }
func (m *ReconnectParams) Reset() { *m = ReconnectParams{} }
func (m *ReconnectParams) String() string { return proto.CompactTextString(m) }
func (*ReconnectParams) ProtoMessage() {}
func (*ReconnectParams) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{9}
}
func (m *ReconnectParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReconnectParams.Unmarshal(m, b)
}
func (m *ReconnectParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ReconnectParams.Marshal(b, m, deterministic)
}
func (dst *ReconnectParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ReconnectParams.Merge(dst, src)
}
func (m *ReconnectParams) XXX_Size() int {
return xxx_messageInfo_ReconnectParams.Size(m)
}
func (m *ReconnectParams) XXX_DiscardUnknown() {
xxx_messageInfo_ReconnectParams.DiscardUnknown(m)
}
var xxx_messageInfo_ReconnectParams proto.InternalMessageInfo
func (m *ReconnectParams) GetMaxReconnectBackoffMs() int32 {
if m != nil {
@@ -392,14 +622,36 @@ func (m *ReconnectParams) GetMaxReconnectBackoffMs() int32 {
// Server tells client whether its reconnects are following the spec and the
// reconnect backoffs it saw.
type ReconnectInfo struct {
Passed bool `protobuf:"varint,1,opt,name=passed" json:"passed,omitempty"`
BackoffMs []int32 `protobuf:"varint,2,rep,packed,name=backoff_ms,json=backoffMs" json:"backoff_ms,omitempty"`
Passed bool `protobuf:"varint,1,opt,name=passed,proto3" json:"passed,omitempty"`
BackoffMs []int32 `protobuf:"varint,2,rep,packed,name=backoff_ms,json=backoffMs,proto3" json:"backoff_ms,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ReconnectInfo) Reset() { *m = ReconnectInfo{} }
func (m *ReconnectInfo) String() string { return proto.CompactTextString(m) }
func (*ReconnectInfo) ProtoMessage() {}
func (*ReconnectInfo) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{10} }
func (m *ReconnectInfo) Reset() { *m = ReconnectInfo{} }
func (m *ReconnectInfo) String() string { return proto.CompactTextString(m) }
func (*ReconnectInfo) ProtoMessage() {}
func (*ReconnectInfo) Descriptor() ([]byte, []int) {
return fileDescriptor_messages_5c70222ad96bf232, []int{10}
}
func (m *ReconnectInfo) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ReconnectInfo.Unmarshal(m, b)
}
func (m *ReconnectInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ReconnectInfo.Marshal(b, m, deterministic)
}
func (dst *ReconnectInfo) XXX_Merge(src proto.Message) {
xxx_messageInfo_ReconnectInfo.Merge(dst, src)
}
func (m *ReconnectInfo) XXX_Size() int {
return xxx_messageInfo_ReconnectInfo.Size(m)
}
func (m *ReconnectInfo) XXX_DiscardUnknown() {
xxx_messageInfo_ReconnectInfo.DiscardUnknown(m)
}
var xxx_messageInfo_ReconnectInfo proto.InternalMessageInfo
func (m *ReconnectInfo) GetPassed() bool {
if m != nil {
@@ -431,9 +683,9 @@ func init() {
proto.RegisterEnum("grpc.testing.CompressionType", CompressionType_name, CompressionType_value)
}
func init() { proto.RegisterFile("messages.proto", fileDescriptor1) }
func init() { proto.RegisterFile("messages.proto", fileDescriptor_messages_5c70222ad96bf232) }
var fileDescriptor1 = []byte{
var fileDescriptor_messages_5c70222ad96bf232 = []byte{
// 652 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xcc, 0x55, 0x4d, 0x6f, 0xd3, 0x40,
0x10, 0xc5, 0xf9, 0xee, 0x24, 0x4d, 0xa3, 0x85, 0x82, 0x5b, 0x54, 0x11, 0x99, 0x4b, 0x54, 0x89,

View File

@@ -12,15 +12,43 @@ var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type ByteBufferParams struct {
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize" json:"resp_size,omitempty"`
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize,proto3" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize,proto3" json:"resp_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ByteBufferParams) Reset() { *m = ByteBufferParams{} }
func (m *ByteBufferParams) String() string { return proto.CompactTextString(m) }
func (*ByteBufferParams) ProtoMessage() {}
func (*ByteBufferParams) Descriptor() ([]byte, []int) { return fileDescriptor2, []int{0} }
func (m *ByteBufferParams) Reset() { *m = ByteBufferParams{} }
func (m *ByteBufferParams) String() string { return proto.CompactTextString(m) }
func (*ByteBufferParams) ProtoMessage() {}
func (*ByteBufferParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{0}
}
func (m *ByteBufferParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ByteBufferParams.Unmarshal(m, b)
}
func (m *ByteBufferParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ByteBufferParams.Marshal(b, m, deterministic)
}
func (dst *ByteBufferParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ByteBufferParams.Merge(dst, src)
}
func (m *ByteBufferParams) XXX_Size() int {
return xxx_messageInfo_ByteBufferParams.Size(m)
}
func (m *ByteBufferParams) XXX_DiscardUnknown() {
xxx_messageInfo_ByteBufferParams.DiscardUnknown(m)
}
var xxx_messageInfo_ByteBufferParams proto.InternalMessageInfo
func (m *ByteBufferParams) GetReqSize() int32 {
if m != nil {
@@ -37,14 +65,36 @@ func (m *ByteBufferParams) GetRespSize() int32 {
}
type SimpleProtoParams struct {
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize" json:"resp_size,omitempty"`
ReqSize int32 `protobuf:"varint,1,opt,name=req_size,json=reqSize,proto3" json:"req_size,omitempty"`
RespSize int32 `protobuf:"varint,2,opt,name=resp_size,json=respSize,proto3" json:"resp_size,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *SimpleProtoParams) Reset() { *m = SimpleProtoParams{} }
func (m *SimpleProtoParams) String() string { return proto.CompactTextString(m) }
func (*SimpleProtoParams) ProtoMessage() {}
func (*SimpleProtoParams) Descriptor() ([]byte, []int) { return fileDescriptor2, []int{1} }
func (m *SimpleProtoParams) Reset() { *m = SimpleProtoParams{} }
func (m *SimpleProtoParams) String() string { return proto.CompactTextString(m) }
func (*SimpleProtoParams) ProtoMessage() {}
func (*SimpleProtoParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{1}
}
func (m *SimpleProtoParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_SimpleProtoParams.Unmarshal(m, b)
}
func (m *SimpleProtoParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_SimpleProtoParams.Marshal(b, m, deterministic)
}
func (dst *SimpleProtoParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_SimpleProtoParams.Merge(dst, src)
}
func (m *SimpleProtoParams) XXX_Size() int {
return xxx_messageInfo_SimpleProtoParams.Size(m)
}
func (m *SimpleProtoParams) XXX_DiscardUnknown() {
xxx_messageInfo_SimpleProtoParams.DiscardUnknown(m)
}
var xxx_messageInfo_SimpleProtoParams proto.InternalMessageInfo
func (m *SimpleProtoParams) GetReqSize() int32 {
if m != nil {
@@ -61,42 +111,90 @@ func (m *SimpleProtoParams) GetRespSize() int32 {
}
type ComplexProtoParams struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ComplexProtoParams) Reset() { *m = ComplexProtoParams{} }
func (m *ComplexProtoParams) String() string { return proto.CompactTextString(m) }
func (*ComplexProtoParams) ProtoMessage() {}
func (*ComplexProtoParams) Descriptor() ([]byte, []int) { return fileDescriptor2, []int{2} }
func (m *ComplexProtoParams) Reset() { *m = ComplexProtoParams{} }
func (m *ComplexProtoParams) String() string { return proto.CompactTextString(m) }
func (*ComplexProtoParams) ProtoMessage() {}
func (*ComplexProtoParams) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{2}
}
func (m *ComplexProtoParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ComplexProtoParams.Unmarshal(m, b)
}
func (m *ComplexProtoParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ComplexProtoParams.Marshal(b, m, deterministic)
}
func (dst *ComplexProtoParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_ComplexProtoParams.Merge(dst, src)
}
func (m *ComplexProtoParams) XXX_Size() int {
return xxx_messageInfo_ComplexProtoParams.Size(m)
}
func (m *ComplexProtoParams) XXX_DiscardUnknown() {
xxx_messageInfo_ComplexProtoParams.DiscardUnknown(m)
}
var xxx_messageInfo_ComplexProtoParams proto.InternalMessageInfo
type PayloadConfig struct {
// Types that are valid to be assigned to Payload:
// *PayloadConfig_BytebufParams
// *PayloadConfig_SimpleParams
// *PayloadConfig_ComplexParams
Payload isPayloadConfig_Payload `protobuf_oneof:"payload"`
Payload isPayloadConfig_Payload `protobuf_oneof:"payload"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *PayloadConfig) Reset() { *m = PayloadConfig{} }
func (m *PayloadConfig) String() string { return proto.CompactTextString(m) }
func (*PayloadConfig) ProtoMessage() {}
func (*PayloadConfig) Descriptor() ([]byte, []int) { return fileDescriptor2, []int{3} }
func (m *PayloadConfig) Reset() { *m = PayloadConfig{} }
func (m *PayloadConfig) String() string { return proto.CompactTextString(m) }
func (*PayloadConfig) ProtoMessage() {}
func (*PayloadConfig) Descriptor() ([]byte, []int) {
return fileDescriptor_payloads_3abc71de35f06c83, []int{3}
}
func (m *PayloadConfig) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_PayloadConfig.Unmarshal(m, b)
}
func (m *PayloadConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_PayloadConfig.Marshal(b, m, deterministic)
}
func (dst *PayloadConfig) XXX_Merge(src proto.Message) {
xxx_messageInfo_PayloadConfig.Merge(dst, src)
}
func (m *PayloadConfig) XXX_Size() int {
return xxx_messageInfo_PayloadConfig.Size(m)
}
func (m *PayloadConfig) XXX_DiscardUnknown() {
xxx_messageInfo_PayloadConfig.DiscardUnknown(m)
}
var xxx_messageInfo_PayloadConfig proto.InternalMessageInfo
type isPayloadConfig_Payload interface {
isPayloadConfig_Payload()
}
type PayloadConfig_BytebufParams struct {
BytebufParams *ByteBufferParams `protobuf:"bytes,1,opt,name=bytebuf_params,json=bytebufParams,oneof"`
BytebufParams *ByteBufferParams `protobuf:"bytes,1,opt,name=bytebuf_params,json=bytebufParams,proto3,oneof"`
}
type PayloadConfig_SimpleParams struct {
SimpleParams *SimpleProtoParams `protobuf:"bytes,2,opt,name=simple_params,json=simpleParams,oneof"`
SimpleParams *SimpleProtoParams `protobuf:"bytes,2,opt,name=simple_params,json=simpleParams,proto3,oneof"`
}
type PayloadConfig_ComplexParams struct {
ComplexParams *ComplexProtoParams `protobuf:"bytes,3,opt,name=complex_params,json=complexParams,oneof"`
ComplexParams *ComplexProtoParams `protobuf:"bytes,3,opt,name=complex_params,json=complexParams,proto3,oneof"`
}
func (*PayloadConfig_BytebufParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_SimpleParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_SimpleParams) isPayloadConfig_Payload() {}
func (*PayloadConfig_ComplexParams) isPayloadConfig_Payload() {}
func (m *PayloadConfig) GetPayload() isPayloadConfig_Payload {
@@ -200,17 +298,17 @@ func _PayloadConfig_OneofSizer(msg proto.Message) (n int) {
switch x := m.Payload.(type) {
case *PayloadConfig_BytebufParams:
s := proto.Size(x.BytebufParams)
n += proto.SizeVarint(1<<3 | proto.WireBytes)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *PayloadConfig_SimpleParams:
s := proto.Size(x.SimpleParams)
n += proto.SizeVarint(2<<3 | proto.WireBytes)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *PayloadConfig_ComplexParams:
s := proto.Size(x.ComplexParams)
n += proto.SizeVarint(3<<3 | proto.WireBytes)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
@@ -227,9 +325,9 @@ func init() {
proto.RegisterType((*PayloadConfig)(nil), "grpc.testing.PayloadConfig")
}
func init() { proto.RegisterFile("payloads.proto", fileDescriptor2) }
func init() { proto.RegisterFile("payloads.proto", fileDescriptor_payloads_3abc71de35f06c83) }
var fileDescriptor2 = []byte{
var fileDescriptor_payloads_3abc71de35f06c83 = []byte{
// 254 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2b, 0x48, 0xac, 0xcc,
0xc9, 0x4f, 0x4c, 0x29, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x49, 0x2f, 0x2a, 0x48,

View File

@@ -17,6 +17,12 @@ var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
@@ -25,8 +31,9 @@ var _ grpc.ClientConn
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for BenchmarkService service
// BenchmarkServiceClient is the client API for BenchmarkService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type BenchmarkServiceClient interface {
// One request followed by one response.
// The server returns the client payload as-is.
@@ -46,7 +53,7 @@ func NewBenchmarkServiceClient(cc *grpc.ClientConn) BenchmarkServiceClient {
func (c *benchmarkServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) {
out := new(SimpleResponse)
err := grpc.Invoke(ctx, "/grpc.testing.BenchmarkService/UnaryCall", in, out, c.cc, opts...)
err := c.cc.Invoke(ctx, "/grpc.testing.BenchmarkService/UnaryCall", in, out, opts...)
if err != nil {
return nil, err
}
@@ -54,7 +61,7 @@ func (c *benchmarkServiceClient) UnaryCall(ctx context.Context, in *SimpleReques
}
func (c *benchmarkServiceClient) StreamingCall(ctx context.Context, opts ...grpc.CallOption) (BenchmarkService_StreamingCallClient, error) {
stream, err := grpc.NewClientStream(ctx, &_BenchmarkService_serviceDesc.Streams[0], c.cc, "/grpc.testing.BenchmarkService/StreamingCall", opts...)
stream, err := c.cc.NewStream(ctx, &_BenchmarkService_serviceDesc.Streams[0], "/grpc.testing.BenchmarkService/StreamingCall", opts...)
if err != nil {
return nil, err
}
@@ -84,8 +91,7 @@ func (x *benchmarkServiceStreamingCallClient) Recv() (*SimpleResponse, error) {
return m, nil
}
// Server API for BenchmarkService service
// BenchmarkServiceServer is the server API for BenchmarkService service.
type BenchmarkServiceServer interface {
// One request followed by one response.
// The server returns the client payload as-is.
@@ -163,8 +169,9 @@ var _BenchmarkService_serviceDesc = grpc.ServiceDesc{
Metadata: "services.proto",
}
// Client API for WorkerService service
// WorkerServiceClient is the client API for WorkerService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type WorkerServiceClient interface {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
@@ -195,7 +202,7 @@ func NewWorkerServiceClient(cc *grpc.ClientConn) WorkerServiceClient {
}
func (c *workerServiceClient) RunServer(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunServerClient, error) {
stream, err := grpc.NewClientStream(ctx, &_WorkerService_serviceDesc.Streams[0], c.cc, "/grpc.testing.WorkerService/RunServer", opts...)
stream, err := c.cc.NewStream(ctx, &_WorkerService_serviceDesc.Streams[0], "/grpc.testing.WorkerService/RunServer", opts...)
if err != nil {
return nil, err
}
@@ -226,7 +233,7 @@ func (x *workerServiceRunServerClient) Recv() (*ServerStatus, error) {
}
func (c *workerServiceClient) RunClient(ctx context.Context, opts ...grpc.CallOption) (WorkerService_RunClientClient, error) {
stream, err := grpc.NewClientStream(ctx, &_WorkerService_serviceDesc.Streams[1], c.cc, "/grpc.testing.WorkerService/RunClient", opts...)
stream, err := c.cc.NewStream(ctx, &_WorkerService_serviceDesc.Streams[1], "/grpc.testing.WorkerService/RunClient", opts...)
if err != nil {
return nil, err
}
@@ -258,7 +265,7 @@ func (x *workerServiceRunClientClient) Recv() (*ClientStatus, error) {
func (c *workerServiceClient) CoreCount(ctx context.Context, in *CoreRequest, opts ...grpc.CallOption) (*CoreResponse, error) {
out := new(CoreResponse)
err := grpc.Invoke(ctx, "/grpc.testing.WorkerService/CoreCount", in, out, c.cc, opts...)
err := c.cc.Invoke(ctx, "/grpc.testing.WorkerService/CoreCount", in, out, opts...)
if err != nil {
return nil, err
}
@@ -267,15 +274,14 @@ func (c *workerServiceClient) CoreCount(ctx context.Context, in *CoreRequest, op
func (c *workerServiceClient) QuitWorker(ctx context.Context, in *Void, opts ...grpc.CallOption) (*Void, error) {
out := new(Void)
err := grpc.Invoke(ctx, "/grpc.testing.WorkerService/QuitWorker", in, out, c.cc, opts...)
err := c.cc.Invoke(ctx, "/grpc.testing.WorkerService/QuitWorker", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for WorkerService service
// WorkerServiceServer is the server API for WorkerService service.
type WorkerServiceServer interface {
// Start server with specified workload.
// First request sent specifies the ServerConfig followed by ServerStatus
@@ -419,9 +425,9 @@ var _WorkerService_serviceDesc = grpc.ServiceDesc{
Metadata: "services.proto",
}
func init() { proto.RegisterFile("services.proto", fileDescriptor3) }
func init() { proto.RegisterFile("services.proto", fileDescriptor_services_bf68f4d7cbd0e0a1) }
var fileDescriptor3 = []byte{
var fileDescriptor_services_bf68f4d7cbd0e0a1 = []byte{
// 255 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x91, 0xc1, 0x4a, 0xc4, 0x30,
0x10, 0x86, 0xa9, 0x07, 0xa1, 0xc1, 0x2e, 0x92, 0x93, 0x46, 0x1f, 0xc0, 0x53, 0x91, 0xd5, 0x17,

View File

@@ -12,20 +12,48 @@ var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type ServerStats struct {
// wall clock time change in seconds since last reset
TimeElapsed float64 `protobuf:"fixed64,1,opt,name=time_elapsed,json=timeElapsed" json:"time_elapsed,omitempty"`
TimeElapsed float64 `protobuf:"fixed64,1,opt,name=time_elapsed,json=timeElapsed,proto3" json:"time_elapsed,omitempty"`
// change in user time (in seconds) used by the server since last reset
TimeUser float64 `protobuf:"fixed64,2,opt,name=time_user,json=timeUser" json:"time_user,omitempty"`
TimeUser float64 `protobuf:"fixed64,2,opt,name=time_user,json=timeUser,proto3" json:"time_user,omitempty"`
// change in server time (in seconds) used by the server process and all
// threads since last reset
TimeSystem float64 `protobuf:"fixed64,3,opt,name=time_system,json=timeSystem" json:"time_system,omitempty"`
TimeSystem float64 `protobuf:"fixed64,3,opt,name=time_system,json=timeSystem,proto3" json:"time_system,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerStats) Reset() { *m = ServerStats{} }
func (m *ServerStats) String() string { return proto.CompactTextString(m) }
func (*ServerStats) ProtoMessage() {}
func (*ServerStats) Descriptor() ([]byte, []int) { return fileDescriptor4, []int{0} }
func (m *ServerStats) Reset() { *m = ServerStats{} }
func (m *ServerStats) String() string { return proto.CompactTextString(m) }
func (*ServerStats) ProtoMessage() {}
func (*ServerStats) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{0}
}
func (m *ServerStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerStats.Unmarshal(m, b)
}
func (m *ServerStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerStats.Marshal(b, m, deterministic)
}
func (dst *ServerStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerStats.Merge(dst, src)
}
func (m *ServerStats) XXX_Size() int {
return xxx_messageInfo_ServerStats.Size(m)
}
func (m *ServerStats) XXX_DiscardUnknown() {
xxx_messageInfo_ServerStats.DiscardUnknown(m)
}
var xxx_messageInfo_ServerStats proto.InternalMessageInfo
func (m *ServerStats) GetTimeElapsed() float64 {
if m != nil {
@@ -50,14 +78,36 @@ func (m *ServerStats) GetTimeSystem() float64 {
// Histogram params based on grpc/support/histogram.c
type HistogramParams struct {
Resolution float64 `protobuf:"fixed64,1,opt,name=resolution" json:"resolution,omitempty"`
MaxPossible float64 `protobuf:"fixed64,2,opt,name=max_possible,json=maxPossible" json:"max_possible,omitempty"`
Resolution float64 `protobuf:"fixed64,1,opt,name=resolution,proto3" json:"resolution,omitempty"`
MaxPossible float64 `protobuf:"fixed64,2,opt,name=max_possible,json=maxPossible,proto3" json:"max_possible,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *HistogramParams) Reset() { *m = HistogramParams{} }
func (m *HistogramParams) String() string { return proto.CompactTextString(m) }
func (*HistogramParams) ProtoMessage() {}
func (*HistogramParams) Descriptor() ([]byte, []int) { return fileDescriptor4, []int{1} }
func (m *HistogramParams) Reset() { *m = HistogramParams{} }
func (m *HistogramParams) String() string { return proto.CompactTextString(m) }
func (*HistogramParams) ProtoMessage() {}
func (*HistogramParams) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{1}
}
func (m *HistogramParams) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_HistogramParams.Unmarshal(m, b)
}
func (m *HistogramParams) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_HistogramParams.Marshal(b, m, deterministic)
}
func (dst *HistogramParams) XXX_Merge(src proto.Message) {
xxx_messageInfo_HistogramParams.Merge(dst, src)
}
func (m *HistogramParams) XXX_Size() int {
return xxx_messageInfo_HistogramParams.Size(m)
}
func (m *HistogramParams) XXX_DiscardUnknown() {
xxx_messageInfo_HistogramParams.DiscardUnknown(m)
}
var xxx_messageInfo_HistogramParams proto.InternalMessageInfo
func (m *HistogramParams) GetResolution() float64 {
if m != nil {
@@ -75,18 +125,40 @@ func (m *HistogramParams) GetMaxPossible() float64 {
// Histogram data based on grpc/support/histogram.c
type HistogramData struct {
Bucket []uint32 `protobuf:"varint,1,rep,packed,name=bucket" json:"bucket,omitempty"`
MinSeen float64 `protobuf:"fixed64,2,opt,name=min_seen,json=minSeen" json:"min_seen,omitempty"`
MaxSeen float64 `protobuf:"fixed64,3,opt,name=max_seen,json=maxSeen" json:"max_seen,omitempty"`
Sum float64 `protobuf:"fixed64,4,opt,name=sum" json:"sum,omitempty"`
SumOfSquares float64 `protobuf:"fixed64,5,opt,name=sum_of_squares,json=sumOfSquares" json:"sum_of_squares,omitempty"`
Count float64 `protobuf:"fixed64,6,opt,name=count" json:"count,omitempty"`
Bucket []uint32 `protobuf:"varint,1,rep,packed,name=bucket,proto3" json:"bucket,omitempty"`
MinSeen float64 `protobuf:"fixed64,2,opt,name=min_seen,json=minSeen,proto3" json:"min_seen,omitempty"`
MaxSeen float64 `protobuf:"fixed64,3,opt,name=max_seen,json=maxSeen,proto3" json:"max_seen,omitempty"`
Sum float64 `protobuf:"fixed64,4,opt,name=sum,proto3" json:"sum,omitempty"`
SumOfSquares float64 `protobuf:"fixed64,5,opt,name=sum_of_squares,json=sumOfSquares,proto3" json:"sum_of_squares,omitempty"`
Count float64 `protobuf:"fixed64,6,opt,name=count,proto3" json:"count,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *HistogramData) Reset() { *m = HistogramData{} }
func (m *HistogramData) String() string { return proto.CompactTextString(m) }
func (*HistogramData) ProtoMessage() {}
func (*HistogramData) Descriptor() ([]byte, []int) { return fileDescriptor4, []int{2} }
func (m *HistogramData) Reset() { *m = HistogramData{} }
func (m *HistogramData) String() string { return proto.CompactTextString(m) }
func (*HistogramData) ProtoMessage() {}
func (*HistogramData) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{2}
}
func (m *HistogramData) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_HistogramData.Unmarshal(m, b)
}
func (m *HistogramData) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_HistogramData.Marshal(b, m, deterministic)
}
func (dst *HistogramData) XXX_Merge(src proto.Message) {
xxx_messageInfo_HistogramData.Merge(dst, src)
}
func (m *HistogramData) XXX_Size() int {
return xxx_messageInfo_HistogramData.Size(m)
}
func (m *HistogramData) XXX_DiscardUnknown() {
xxx_messageInfo_HistogramData.DiscardUnknown(m)
}
var xxx_messageInfo_HistogramData proto.InternalMessageInfo
func (m *HistogramData) GetBucket() []uint32 {
if m != nil {
@@ -132,17 +204,39 @@ func (m *HistogramData) GetCount() float64 {
type ClientStats struct {
// Latency histogram. Data points are in nanoseconds.
Latencies *HistogramData `protobuf:"bytes,1,opt,name=latencies" json:"latencies,omitempty"`
Latencies *HistogramData `protobuf:"bytes,1,opt,name=latencies,proto3" json:"latencies,omitempty"`
// See ServerStats for details.
TimeElapsed float64 `protobuf:"fixed64,2,opt,name=time_elapsed,json=timeElapsed" json:"time_elapsed,omitempty"`
TimeUser float64 `protobuf:"fixed64,3,opt,name=time_user,json=timeUser" json:"time_user,omitempty"`
TimeSystem float64 `protobuf:"fixed64,4,opt,name=time_system,json=timeSystem" json:"time_system,omitempty"`
TimeElapsed float64 `protobuf:"fixed64,2,opt,name=time_elapsed,json=timeElapsed,proto3" json:"time_elapsed,omitempty"`
TimeUser float64 `protobuf:"fixed64,3,opt,name=time_user,json=timeUser,proto3" json:"time_user,omitempty"`
TimeSystem float64 `protobuf:"fixed64,4,opt,name=time_system,json=timeSystem,proto3" json:"time_system,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) { return fileDescriptor4, []int{3} }
func (m *ClientStats) Reset() { *m = ClientStats{} }
func (m *ClientStats) String() string { return proto.CompactTextString(m) }
func (*ClientStats) ProtoMessage() {}
func (*ClientStats) Descriptor() ([]byte, []int) {
return fileDescriptor_stats_8ba831c0cb3c3440, []int{3}
}
func (m *ClientStats) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientStats.Unmarshal(m, b)
}
func (m *ClientStats) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientStats.Marshal(b, m, deterministic)
}
func (dst *ClientStats) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientStats.Merge(dst, src)
}
func (m *ClientStats) XXX_Size() int {
return xxx_messageInfo_ClientStats.Size(m)
}
func (m *ClientStats) XXX_DiscardUnknown() {
xxx_messageInfo_ClientStats.DiscardUnknown(m)
}
var xxx_messageInfo_ClientStats proto.InternalMessageInfo
func (m *ClientStats) GetLatencies() *HistogramData {
if m != nil {
@@ -179,9 +273,9 @@ func init() {
proto.RegisterType((*ClientStats)(nil), "grpc.testing.ClientStats")
}
func init() { proto.RegisterFile("stats.proto", fileDescriptor4) }
func init() { proto.RegisterFile("stats.proto", fileDescriptor_stats_8ba831c0cb3c3440) }
var fileDescriptor4 = []byte{
var fileDescriptor_stats_8ba831c0cb3c3440 = []byte{
// 341 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x92, 0xc1, 0x4a, 0xeb, 0x40,
0x14, 0x86, 0x49, 0xd3, 0xf6, 0xb6, 0x27, 0xed, 0xbd, 0x97, 0x41, 0x24, 0x52, 0xd0, 0x1a, 0x5c,

187
vendor/google.golang.org/grpc/benchmark/run_bench.sh generated vendored Executable file
View File

@@ -0,0 +1,187 @@
#!/bin/bash
rpcs=(1)
conns=(1)
warmup=10
dur=10
reqs=(1)
resps=(1)
rpc_types=(unary)
# idx[0] = idx value for rpcs
# idx[1] = idx value for conns
# idx[2] = idx value for reqs
# idx[3] = idx value for resps
# idx[4] = idx value for rpc_types
idx=(0 0 0 0 0)
idx_max=(1 1 1 1 1)
inc()
{
for i in $(seq $((${#idx[@]}-1)) -1 0); do
idx[${i}]=$((${idx[${i}]}+1))
if [ ${idx[${i}]} == ${idx_max[${i}]} ]; then
idx[${i}]=0
else
break
fi
done
local fin
fin=1
# Check to see if we have looped back to the beginning.
for v in ${idx[@]}; do
if [ ${v} != 0 ]; then
fin=0
break
fi
done
if [ ${fin} == 1 ]; then
rm -Rf ${out_dir}
clean_and_die 0
fi
}
clean_and_die() {
rm -Rf ${out_dir}
exit $1
}
run(){
local nr
nr=${rpcs[${idx[0]}]}
local nc
nc=${conns[${idx[1]}]}
req_sz=${reqs[${idx[2]}]}
resp_sz=${resps[${idx[3]}]}
r_type=${rpc_types[${idx[4]}]}
# Following runs one benchmark
base_port=50051
delta=0
test_name="r_"${nr}"_c_"${nc}"_req_"${req_sz}"_resp_"${resp_sz}"_"${r_type}"_"$(date +%s)
echo "================================================================================"
echo ${test_name}
while :
do
port=$((${base_port}+${delta}))
# Launch the server in background
${out_dir}/server --port=${port} --test_name="Server_"${test_name}&
server_pid=$(echo $!)
# Launch the client
${out_dir}/client --port=${port} --d=${dur} --w=${warmup} --r=${nr} --c=${nc} --req=${req_sz} --resp=${resp_sz} --rpc_type=${r_type} --test_name="client_"${test_name}
client_status=$(echo $?)
kill -INT ${server_pid}
wait ${server_pid}
if [ ${client_status} == 0 ]; then
break
fi
delta=$((${delta}+1))
if [ ${delta} == 10 ]; then
echo "Continuous 10 failed runs. Exiting now."
rm -Rf ${out_dir}
clean_and_die 1
fi
done
}
set_param(){
local argname=$1
shift
local idx=$1
shift
if [ $# -eq 0 ]; then
echo "${argname} not specified"
exit 1
fi
PARAM=($(echo $1 | sed 's/,/ /g'))
if [ ${idx} -lt 0 ]; then
return
fi
idx_max[${idx}]=${#PARAM[@]}
}
while [ $# -gt 0 ]; do
case "$1" in
-r)
shift
set_param "number of rpcs" 0 $1
rpcs=(${PARAM[@]})
shift
;;
-c)
shift
set_param "number of connections" 1 $1
conns=(${PARAM[@]})
shift
;;
-w)
shift
set_param "warm-up period" -1 $1
warmup=${PARAM}
shift
;;
-d)
shift
set_param "duration" -1 $1
dur=${PARAM}
shift
;;
-req)
shift
set_param "request size" 2 $1
reqs=(${PARAM[@]})
shift
;;
-resp)
shift
set_param "response size" 3 $1
resps=(${PARAM[@]})
shift
;;
-rpc_type)
shift
set_param "rpc type" 4 $1
rpc_types=(${PARAM[@]})
shift
;;
-h|--help)
echo "Following are valid options:"
echo
echo "-h, --help show brief help"
echo "-w warm-up duration in seconds, default value is 10"
echo "-d benchmark duration in seconds, default value is 60"
echo ""
echo "Each of the following can have multiple comma separated values."
echo ""
echo "-r number of RPCs, default value is 1"
echo "-c number of Connections, default value is 1"
echo "-req req size in bytes, default value is 1"
echo "-resp resp size in bytes, default value is 1"
echo "-rpc_type valid values are unary|streaming, default is unary"
;;
*)
echo "Incorrect option $1"
exit 1
;;
esac
done
# Build server and client
out_dir=$(mktemp -d oss_benchXXX)
go build -o ${out_dir}/server $GOPATH/src/google.golang.org/grpc/benchmark/server/main.go && go build -o ${out_dir}/client $GOPATH/src/google.golang.org/grpc/benchmark/client/main.go
if [ $? != 0 ]; then
clean_and_die 1
fi
while :
do
run
inc
done

View File

@@ -20,32 +20,62 @@ package main
import (
"flag"
"math"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"runtime"
"runtime/pprof"
"time"
"google.golang.org/grpc/benchmark"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
)
var duration = flag.Int("duration", math.MaxInt32, "The duration in seconds to run the benchmark server")
var (
port = flag.String("port", "50051", "Localhost port to listen on.")
testName = flag.String("test_name", "", "Name of the test used for creating profiles.")
)
func main() {
flag.Parse()
go func() {
lis, err := net.Listen("tcp", ":0")
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
grpclog.Println("Server profiling address: ", lis.Addr().String())
if err := http.Serve(lis, nil); err != nil {
grpclog.Fatalf("Failed to serve: %v", err)
}
}()
addr, stopper := benchmark.StartServer(benchmark.ServerInfo{Addr: ":0", Type: "protobuf"}) // listen on all interfaces
grpclog.Println("Server Address: ", addr)
<-time.After(time.Duration(*duration) * time.Second)
stopper()
if *testName == "" {
grpclog.Fatalf("test name not set")
}
lis, err := net.Listen("tcp", ":"+*port)
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
cf, err := os.Create("/tmp/" + *testName + ".cpu")
if err != nil {
grpclog.Fatalf("Failed to create file: %v", err)
}
defer cf.Close()
pprof.StartCPUProfile(cf)
cpuBeg := syscall.GetCPUTime()
// Launch server in a separate goroutine.
stop := benchmark.StartServer(benchmark.ServerInfo{Type: "protobuf", Listener: lis})
// Wait on OS terminate signal.
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
<-ch
cpu := time.Duration(syscall.GetCPUTime() - cpuBeg)
stop()
pprof.StopCPUProfile()
mf, err := os.Create("/tmp/" + *testName + ".mem")
if err != nil {
grpclog.Fatalf("Failed to create file: %v", err)
}
defer mf.Close()
runtime.GC() // materialize all statistics
if err := pprof.WriteHeapProfile(mf); err != nil {
grpclog.Fatalf("Failed to write memory profile: %v", err)
}
fmt.Println("Server CPU utilization:", cpu)
fmt.Println("Server CPU profile:", cf.Name())
fmt.Println("Server Mem Profile:", mf.Name())
}

View File

@@ -39,6 +39,7 @@ type Features struct {
ReqSizeBytes int
RespSizeBytes int
EnableCompressor bool
EnableChannelz bool
}
// String returns the textual output of the Features as string.
@@ -48,6 +49,13 @@ func (f Features) String() string {
f.Latency.String(), f.Kbps, f.Mtu, f.MaxConcurrentCalls, f.ReqSizeBytes, f.RespSizeBytes, f.EnableCompressor)
}
// ConciseString returns the concise textual output of the Features as string, skipping
// setting with default value.
func (f Features) ConciseString() string {
noneEmptyPos := []bool{f.EnableTrace, f.Latency != 0, f.Kbps != 0, f.Mtu != 0, true, true, true, f.EnableCompressor, f.EnableChannelz}
return PartialPrintString(noneEmptyPos, f, false)
}
// PartialPrintString can print certain features with different format.
func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string {
s := ""
@@ -63,7 +71,7 @@ func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string {
linker = "_"
}
if noneEmptyPos[0] {
s += fmt.Sprintf("%sTrace%s%t%s", prefix, linker, f.EnableCompressor, suffix)
s += fmt.Sprintf("%sTrace%s%t%s", prefix, linker, f.EnableTrace, suffix)
}
if shared && f.NetworkMode != "" {
s += fmt.Sprintf("Network: %s \n", f.NetworkMode)
@@ -92,6 +100,9 @@ func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string {
if noneEmptyPos[7] {
s += fmt.Sprintf("%sCompressor%s%t%s", prefix, linker, f.EnableCompressor, suffix)
}
if noneEmptyPos[8] {
s += fmt.Sprintf("%sChannelz%s%t%s", prefix, linker, f.EnableChannelz, suffix)
}
return s
}

View File

@@ -23,7 +23,6 @@ import (
"math"
"runtime"
"sync"
"syscall"
"time"
"golang.org/x/net/context"
@@ -34,6 +33,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
@@ -51,12 +51,12 @@ func (h *lockingHistogram) add(value int64) {
h.histogram.Add(value)
}
// swap sets h.histogram to new, and returns its old value.
func (h *lockingHistogram) swap(new *stats.Histogram) *stats.Histogram {
// swap sets h.histogram to o and returns its old value.
func (h *lockingHistogram) swap(o *stats.Histogram) *stats.Histogram {
h.mu.Lock()
defer h.mu.Unlock()
old := h.histogram
h.histogram = new
h.histogram = o
return old
}
@@ -81,20 +81,20 @@ func printClientConfig(config *testpb.ClientConfig) {
// will always create sync client
// - async client threads.
// - core list
grpclog.Printf(" * client type: %v (ignored, always creates sync client)", config.ClientType)
grpclog.Printf(" * async client threads: %v (ignored)", config.AsyncClientThreads)
grpclog.Infof(" * client type: %v (ignored, always creates sync client)", config.ClientType)
grpclog.Infof(" * async client threads: %v (ignored)", config.AsyncClientThreads)
// TODO: use cores specified by CoreList when setting list of cores is supported in go.
grpclog.Printf(" * core list: %v (ignored)", config.CoreList)
grpclog.Infof(" * core list: %v (ignored)", config.CoreList)
grpclog.Printf(" - security params: %v", config.SecurityParams)
grpclog.Printf(" - core limit: %v", config.CoreLimit)
grpclog.Printf(" - payload config: %v", config.PayloadConfig)
grpclog.Printf(" - rpcs per chann: %v", config.OutstandingRpcsPerChannel)
grpclog.Printf(" - channel number: %v", config.ClientChannels)
grpclog.Printf(" - load params: %v", config.LoadParams)
grpclog.Printf(" - rpc type: %v", config.RpcType)
grpclog.Printf(" - histogram params: %v", config.HistogramParams)
grpclog.Printf(" - server targets: %v", config.ServerTargets)
grpclog.Infof(" - security params: %v", config.SecurityParams)
grpclog.Infof(" - core limit: %v", config.CoreLimit)
grpclog.Infof(" - payload config: %v", config.PayloadConfig)
grpclog.Infof(" - rpcs per chann: %v", config.OutstandingRpcsPerChannel)
grpclog.Infof(" - channel number: %v", config.ClientChannels)
grpclog.Infof(" - load params: %v", config.LoadParams)
grpclog.Infof(" - rpc type: %v", config.RpcType)
grpclog.Infof(" - histogram params: %v", config.HistogramParams)
grpclog.Infof(" - server targets: %v", config.ServerTargets)
}
func setupClientEnv(config *testpb.ClientConfig) {
@@ -118,7 +118,7 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error
case testpb.ClientType_SYNC_CLIENT:
case testpb.ClientType_ASYNC_CLIENT:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknow client type: %v", config.ClientType)
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown client type: %v", config.ClientType)
}
// Check and set security options.
@@ -142,13 +142,13 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error
opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(byteBufCodec{})))
case *testpb.PayloadConfig_SimpleParams:
default:
return nil, nil, status.Errorf(codes.InvalidArgument, "unknow payload config: %v", config.PayloadConfig)
return nil, nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
}
// Create connections.
connCount := int(config.ClientChannels)
conns := make([]*grpc.ClientConn, connCount, connCount)
conns := make([]*grpc.ClientConn, connCount)
for connIndex := 0; connIndex < connCount; connIndex++ {
conns[connIndex] = benchmark.NewClientConn(config.ServerTargets[connIndex%len(config.ServerTargets)], opts...)
}
@@ -177,7 +177,7 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc
payloadRespSize = int(c.SimpleParams.RespSize)
payloadType = "protobuf"
default:
return status.Errorf(codes.InvalidArgument, "unknow payload config: %v", config.PayloadConfig)
return status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
}
@@ -217,9 +217,6 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error)
return nil, err
}
rusage := new(syscall.Rusage)
syscall.Getrusage(syscall.RUSAGE_SELF, rusage)
rpcCountPerConn := int(config.OutstandingRpcsPerChannel)
bc := &benchmarkClient{
histogramOptions: stats.HistogramOptions{
@@ -228,12 +225,12 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error)
BaseBucketSize: (1 + config.HistogramParams.Resolution),
MinValue: 0,
},
lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns), rpcCountPerConn*len(conns)),
lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)),
stop: make(chan bool),
lastResetTime: time.Now(),
closeConns: closeConns,
rusageLastReset: rusage,
rusageLastReset: syscall.GetRusage(),
}
if err = performRPCs(config, conns, bc); err != nil {
@@ -335,12 +332,11 @@ func (bc *benchmarkClient) doCloseLoopStreaming(conns []*grpc.ClientConn, rpcCou
func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
var wallTimeElapsed, uTimeElapsed, sTimeElapsed float64
mergedHistogram := stats.NewHistogram(bc.histogramOptions)
latestRusage := new(syscall.Rusage)
if reset {
// Merging histogram may take some time.
// Put all histograms aside and merge later.
toMerge := make([]*stats.Histogram, len(bc.lockingHistograms), len(bc.lockingHistograms))
toMerge := make([]*stats.Histogram, len(bc.lockingHistograms))
for i := range bc.lockingHistograms {
toMerge[i] = bc.lockingHistograms[i].swap(stats.NewHistogram(bc.histogramOptions))
}
@@ -350,8 +346,8 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
}
wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
syscall.Getrusage(syscall.RUSAGE_SELF, latestRusage)
uTimeElapsed, sTimeElapsed = cpuTimeDiff(bc.rusageLastReset, latestRusage)
latestRusage := syscall.GetRusage()
uTimeElapsed, sTimeElapsed = syscall.CPUTimeDiff(bc.rusageLastReset, latestRusage)
bc.rusageLastReset = latestRusage
bc.lastResetTime = time.Now()
@@ -362,11 +358,10 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
}
wallTimeElapsed = time.Since(bc.lastResetTime).Seconds()
syscall.Getrusage(syscall.RUSAGE_SELF, latestRusage)
uTimeElapsed, sTimeElapsed = cpuTimeDiff(bc.rusageLastReset, latestRusage)
uTimeElapsed, sTimeElapsed = syscall.CPUTimeDiff(bc.rusageLastReset, syscall.GetRusage())
}
b := make([]uint32, len(mergedHistogram.Buckets), len(mergedHistogram.Buckets))
b := make([]uint32, len(mergedHistogram.Buckets))
for i, v := range mergedHistogram.Buckets {
b[i] = uint32(v.Count)
}

View File

@@ -20,11 +20,12 @@ package main
import (
"flag"
"fmt"
"net"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
"google.golang.org/grpc"
@@ -33,6 +34,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
)
@@ -57,15 +59,15 @@ func printServerConfig(config *testpb.ServerConfig) {
// will always start sync server
// - async server threads
// - core list
grpclog.Printf(" * server type: %v (ignored, always starts sync server)", config.ServerType)
grpclog.Printf(" * async server threads: %v (ignored)", config.AsyncServerThreads)
grpclog.Infof(" * server type: %v (ignored, always starts sync server)", config.ServerType)
grpclog.Infof(" * async server threads: %v (ignored)", config.AsyncServerThreads)
// TODO: use cores specified by CoreList when setting list of cores is supported in go.
grpclog.Printf(" * core list: %v (ignored)", config.CoreList)
grpclog.Infof(" * core list: %v (ignored)", config.CoreList)
grpclog.Printf(" - security params: %v", config.SecurityParams)
grpclog.Printf(" - core limit: %v", config.CoreLimit)
grpclog.Printf(" - port: %v", config.Port)
grpclog.Printf(" - payload config: %v", config.PayloadConfig)
grpclog.Infof(" - security params: %v", config.SecurityParams)
grpclog.Infof(" - core limit: %v", config.CoreLimit)
grpclog.Infof(" - port: %v", config.Port)
grpclog.Infof(" - payload config: %v", config.PayloadConfig)
}
func startBenchmarkServer(config *testpb.ServerConfig, serverPort int) (*benchmarkServer, error) {
@@ -87,7 +89,7 @@ func startBenchmarkServer(config *testpb.ServerConfig, serverPort int) (*benchma
case testpb.ServerType_ASYNC_SERVER:
case testpb.ServerType_ASYNC_GENERIC_SERVER:
default:
return nil, status.Errorf(codes.InvalidArgument, "unknow server type: %v", config.ServerType)
return nil, status.Errorf(codes.InvalidArgument, "unknown server type: %v", config.ServerType)
}
// Set security options.
@@ -110,56 +112,54 @@ func startBenchmarkServer(config *testpb.ServerConfig, serverPort int) (*benchma
if port == 0 {
port = serverPort
}
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
addr := lis.Addr().String()
// Create different benchmark server according to config.
var (
addr string
closeFunc func()
err error
)
var closeFunc func()
if config.PayloadConfig != nil {
switch payload := config.PayloadConfig.Payload.(type) {
case *testpb.PayloadConfig_BytebufParams:
opts = append(opts, grpc.CustomCodec(byteBufCodec{}))
addr, closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Addr: ":" + strconv.Itoa(port),
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "bytebuf",
Metadata: payload.BytebufParams.RespSize,
Listener: lis,
}, opts...)
case *testpb.PayloadConfig_SimpleParams:
addr, closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Addr: ":" + strconv.Itoa(port),
Type: "protobuf",
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "protobuf",
Listener: lis,
}, opts...)
case *testpb.PayloadConfig_ComplexParams:
return nil, status.Errorf(codes.Unimplemented, "unsupported payload config: %v", config.PayloadConfig)
default:
return nil, status.Errorf(codes.InvalidArgument, "unknow payload config: %v", config.PayloadConfig)
return nil, status.Errorf(codes.InvalidArgument, "unknown payload config: %v", config.PayloadConfig)
}
} else {
// Start protobuf server if payload config is nil.
addr, closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Addr: ":" + strconv.Itoa(port),
Type: "protobuf",
closeFunc = benchmark.StartServer(benchmark.ServerInfo{
Type: "protobuf",
Listener: lis,
}, opts...)
}
grpclog.Printf("benchmark server listening at %v", addr)
grpclog.Infof("benchmark server listening at %v", addr)
addrSplitted := strings.Split(addr, ":")
p, err := strconv.Atoi(addrSplitted[len(addrSplitted)-1])
if err != nil {
grpclog.Fatalf("failed to get port number from server address: %v", err)
}
rusage := new(syscall.Rusage)
syscall.Getrusage(syscall.RUSAGE_SELF, rusage)
return &benchmarkServer{
port: p,
cores: numOfCores,
closeFunc: closeFunc,
lastResetTime: time.Now(),
rusageLastReset: rusage,
rusageLastReset: syscall.GetRusage(),
}, nil
}
@@ -169,9 +169,8 @@ func (bs *benchmarkServer) getStats(reset bool) *testpb.ServerStats {
bs.mu.RLock()
defer bs.mu.RUnlock()
wallTimeElapsed := time.Since(bs.lastResetTime).Seconds()
rusageLatest := new(syscall.Rusage)
syscall.Getrusage(syscall.RUSAGE_SELF, rusageLatest)
uTimeElapsed, sTimeElapsed := cpuTimeDiff(bs.rusageLastReset, rusageLatest)
rusageLatest := syscall.GetRusage()
uTimeElapsed, sTimeElapsed := syscall.CPUTimeDiff(bs.rusageLastReset, rusageLatest)
if reset {
bs.lastResetTime = time.Now()

View File

@@ -79,7 +79,7 @@ func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) er
var bs *benchmarkServer
defer func() {
// Close benchmark server when stream ends.
grpclog.Printf("closing benchmark server")
grpclog.Infof("closing benchmark server")
if bs != nil {
bs.closeFunc()
}
@@ -96,9 +96,9 @@ func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) er
var out *testpb.ServerStatus
switch argtype := in.Argtype.(type) {
case *testpb.ServerArgs_Setup:
grpclog.Printf("server setup received:")
grpclog.Infof("server setup received:")
if bs != nil {
grpclog.Printf("server setup received when server already exists, closing the existing server")
grpclog.Infof("server setup received when server already exists, closing the existing server")
bs.closeFunc()
}
bs, err = startBenchmarkServer(argtype.Setup, s.serverPort)
@@ -112,8 +112,8 @@ func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) er
}
case *testpb.ServerArgs_Mark:
grpclog.Printf("server mark received:")
grpclog.Printf(" - %v", argtype)
grpclog.Infof("server mark received:")
grpclog.Infof(" - %v", argtype)
if bs == nil {
return status.Error(codes.InvalidArgument, "server does not exist when mark received")
}
@@ -134,7 +134,7 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er
var bc *benchmarkClient
defer func() {
// Shut down benchmark client when stream ends.
grpclog.Printf("shuting down benchmark client")
grpclog.Infof("shuting down benchmark client")
if bc != nil {
bc.shutdown()
}
@@ -151,9 +151,9 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er
var out *testpb.ClientStatus
switch t := in.Argtype.(type) {
case *testpb.ClientArgs_Setup:
grpclog.Printf("client setup received:")
grpclog.Infof("client setup received:")
if bc != nil {
grpclog.Printf("client setup received when client already exists, shuting down the existing client")
grpclog.Infof("client setup received when client already exists, shuting down the existing client")
bc.shutdown()
}
bc, err = startBenchmarkClient(t.Setup)
@@ -165,8 +165,8 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er
}
case *testpb.ClientArgs_Mark:
grpclog.Printf("client mark received:")
grpclog.Printf(" - %v", t)
grpclog.Infof("client mark received:")
grpclog.Infof(" - %v", t)
if bc == nil {
return status.Error(codes.InvalidArgument, "client does not exist when mark received")
}
@@ -182,12 +182,12 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er
}
func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) {
grpclog.Printf("core count: %v", runtime.NumCPU())
grpclog.Infof("core count: %v", runtime.NumCPU())
return &testpb.CoreResponse{Cores: int32(runtime.NumCPU())}, nil
}
func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) {
grpclog.Printf("quiting worker")
grpclog.Infof("quitting worker")
s.stop <- true
return &testpb.Void{}, nil
}
@@ -200,7 +200,7 @@ func main() {
if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
}
grpclog.Printf("worker listening at port %v", *driverPort)
grpclog.Infof("worker listening at port %v", *driverPort)
s := grpc.NewServer()
stop := make(chan bool)
@@ -221,8 +221,8 @@ func main() {
if *pprofPort >= 0 {
go func() {
grpclog.Println("Starting pprof server on port " + strconv.Itoa(*pprofPort))
grpclog.Println(http.ListenAndServe("localhost:"+strconv.Itoa(*pprofPort), nil))
grpclog.Infoln("Starting pprof server on port " + strconv.Itoa(*pprofPort))
grpclog.Infoln(http.ListenAndServe("localhost:"+strconv.Itoa(*pprofPort), nil))
}()
}

View File

@@ -0,0 +1,900 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/binarylog/grpc_binarylog_v1/binarylog.proto
package grpc_binarylog_v1 // import "google.golang.org/grpc/binarylog/grpc_binarylog_v1"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import duration "github.com/golang/protobuf/ptypes/duration"
import timestamp "github.com/golang/protobuf/ptypes/timestamp"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// Enumerates the type of event
// Note the terminology is different from the RPC semantics
// definition, but the same meaning is expressed here.
type GrpcLogEntry_EventType int32
const (
GrpcLogEntry_EVENT_TYPE_UNKNOWN GrpcLogEntry_EventType = 0
// Header sent from client to server
GrpcLogEntry_EVENT_TYPE_CLIENT_HEADER GrpcLogEntry_EventType = 1
// Header sent from server to client
GrpcLogEntry_EVENT_TYPE_SERVER_HEADER GrpcLogEntry_EventType = 2
// Message sent from client to server
GrpcLogEntry_EVENT_TYPE_CLIENT_MESSAGE GrpcLogEntry_EventType = 3
// Message sent from server to client
GrpcLogEntry_EVENT_TYPE_SERVER_MESSAGE GrpcLogEntry_EventType = 4
// A signal that client is done sending
GrpcLogEntry_EVENT_TYPE_CLIENT_HALF_CLOSE GrpcLogEntry_EventType = 5
// Trailer indicates the end of the RPC.
// On client side, this event means a trailer was either received
// from the network or the gRPC library locally generated a status
// to inform the application about a failure.
// On server side, this event means the server application requested
// to send a trailer. Note: EVENT_TYPE_CANCEL may still arrive after
// this due to races on server side.
GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER GrpcLogEntry_EventType = 6
// A signal that the RPC is cancelled. On client side, this
// indicates the client application requests a cancellation.
// On server side, this indicates that cancellation was detected.
// Note: This marks the end of the RPC. Events may arrive after
// this due to races. For example, on client side a trailer
// may arrive even though the application requested to cancel the RPC.
GrpcLogEntry_EVENT_TYPE_CANCEL GrpcLogEntry_EventType = 7
)
var GrpcLogEntry_EventType_name = map[int32]string{
0: "EVENT_TYPE_UNKNOWN",
1: "EVENT_TYPE_CLIENT_HEADER",
2: "EVENT_TYPE_SERVER_HEADER",
3: "EVENT_TYPE_CLIENT_MESSAGE",
4: "EVENT_TYPE_SERVER_MESSAGE",
5: "EVENT_TYPE_CLIENT_HALF_CLOSE",
6: "EVENT_TYPE_SERVER_TRAILER",
7: "EVENT_TYPE_CANCEL",
}
var GrpcLogEntry_EventType_value = map[string]int32{
"EVENT_TYPE_UNKNOWN": 0,
"EVENT_TYPE_CLIENT_HEADER": 1,
"EVENT_TYPE_SERVER_HEADER": 2,
"EVENT_TYPE_CLIENT_MESSAGE": 3,
"EVENT_TYPE_SERVER_MESSAGE": 4,
"EVENT_TYPE_CLIENT_HALF_CLOSE": 5,
"EVENT_TYPE_SERVER_TRAILER": 6,
"EVENT_TYPE_CANCEL": 7,
}
func (x GrpcLogEntry_EventType) String() string {
return proto.EnumName(GrpcLogEntry_EventType_name, int32(x))
}
func (GrpcLogEntry_EventType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{0, 0}
}
// Enumerates the entity that generates the log entry
type GrpcLogEntry_Logger int32
const (
GrpcLogEntry_LOGGER_UNKNOWN GrpcLogEntry_Logger = 0
GrpcLogEntry_LOGGER_CLIENT GrpcLogEntry_Logger = 1
GrpcLogEntry_LOGGER_SERVER GrpcLogEntry_Logger = 2
)
var GrpcLogEntry_Logger_name = map[int32]string{
0: "LOGGER_UNKNOWN",
1: "LOGGER_CLIENT",
2: "LOGGER_SERVER",
}
var GrpcLogEntry_Logger_value = map[string]int32{
"LOGGER_UNKNOWN": 0,
"LOGGER_CLIENT": 1,
"LOGGER_SERVER": 2,
}
func (x GrpcLogEntry_Logger) String() string {
return proto.EnumName(GrpcLogEntry_Logger_name, int32(x))
}
func (GrpcLogEntry_Logger) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{0, 1}
}
type Address_Type int32
const (
Address_TYPE_UNKNOWN Address_Type = 0
// address is in 1.2.3.4 form
Address_TYPE_IPV4 Address_Type = 1
// address is in IPv6 canonical form (RFC5952 section 4)
// The scope is NOT included in the address string.
Address_TYPE_IPV6 Address_Type = 2
// address is UDS string
Address_TYPE_UNIX Address_Type = 3
)
var Address_Type_name = map[int32]string{
0: "TYPE_UNKNOWN",
1: "TYPE_IPV4",
2: "TYPE_IPV6",
3: "TYPE_UNIX",
}
var Address_Type_value = map[string]int32{
"TYPE_UNKNOWN": 0,
"TYPE_IPV4": 1,
"TYPE_IPV6": 2,
"TYPE_UNIX": 3,
}
func (x Address_Type) String() string {
return proto.EnumName(Address_Type_name, int32(x))
}
func (Address_Type) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{7, 0}
}
// Log entry we store in binary logs
type GrpcLogEntry struct {
// The timestamp of the binary log message
Timestamp *timestamp.Timestamp `protobuf:"bytes,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// Uniquely identifies a call. The value must not be 0 in order to disambiguate
// from an unset value.
// Each call may have several log entries, they will all have the same call_id.
// Nothing is guaranteed about their value other than they are unique across
// different RPCs in the same gRPC process.
CallId uint64 `protobuf:"varint,2,opt,name=call_id,json=callId,proto3" json:"call_id,omitempty"`
// The entry sequence id for this call. The first GrpcLogEntry has a
// value of 1, to disambiguate from an unset value. The purpose of
// this field is to detect missing entries in environments where
// durability or ordering is not guaranteed.
SequenceIdWithinCall uint64 `protobuf:"varint,3,opt,name=sequence_id_within_call,json=sequenceIdWithinCall,proto3" json:"sequence_id_within_call,omitempty"`
Type GrpcLogEntry_EventType `protobuf:"varint,4,opt,name=type,proto3,enum=grpc.binarylog.v1.GrpcLogEntry_EventType" json:"type,omitempty"`
Logger GrpcLogEntry_Logger `protobuf:"varint,5,opt,name=logger,proto3,enum=grpc.binarylog.v1.GrpcLogEntry_Logger" json:"logger,omitempty"`
// The logger uses one of the following fields to record the payload,
// according to the type of the log entry.
//
// Types that are valid to be assigned to Payload:
// *GrpcLogEntry_ClientHeader
// *GrpcLogEntry_ServerHeader
// *GrpcLogEntry_Message
// *GrpcLogEntry_Trailer
Payload isGrpcLogEntry_Payload `protobuf_oneof:"payload"`
// true if payload does not represent the full message or metadata.
PayloadTruncated bool `protobuf:"varint,10,opt,name=payload_truncated,json=payloadTruncated,proto3" json:"payload_truncated,omitempty"`
// Peer address information, will only be recorded on the first
// incoming event. On client side, peer is logged on
// EVENT_TYPE_SERVER_HEADER normally or EVENT_TYPE_SERVER_TRAILER in
// the case of trailers-only. On server side, peer is always
// logged on EVENT_TYPE_CLIENT_HEADER.
Peer *Address `protobuf:"bytes,11,opt,name=peer,proto3" json:"peer,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *GrpcLogEntry) Reset() { *m = GrpcLogEntry{} }
func (m *GrpcLogEntry) String() string { return proto.CompactTextString(m) }
func (*GrpcLogEntry) ProtoMessage() {}
func (*GrpcLogEntry) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{0}
}
func (m *GrpcLogEntry) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_GrpcLogEntry.Unmarshal(m, b)
}
func (m *GrpcLogEntry) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GrpcLogEntry.Marshal(b, m, deterministic)
}
func (dst *GrpcLogEntry) XXX_Merge(src proto.Message) {
xxx_messageInfo_GrpcLogEntry.Merge(dst, src)
}
func (m *GrpcLogEntry) XXX_Size() int {
return xxx_messageInfo_GrpcLogEntry.Size(m)
}
func (m *GrpcLogEntry) XXX_DiscardUnknown() {
xxx_messageInfo_GrpcLogEntry.DiscardUnknown(m)
}
var xxx_messageInfo_GrpcLogEntry proto.InternalMessageInfo
func (m *GrpcLogEntry) GetTimestamp() *timestamp.Timestamp {
if m != nil {
return m.Timestamp
}
return nil
}
func (m *GrpcLogEntry) GetCallId() uint64 {
if m != nil {
return m.CallId
}
return 0
}
func (m *GrpcLogEntry) GetSequenceIdWithinCall() uint64 {
if m != nil {
return m.SequenceIdWithinCall
}
return 0
}
func (m *GrpcLogEntry) GetType() GrpcLogEntry_EventType {
if m != nil {
return m.Type
}
return GrpcLogEntry_EVENT_TYPE_UNKNOWN
}
func (m *GrpcLogEntry) GetLogger() GrpcLogEntry_Logger {
if m != nil {
return m.Logger
}
return GrpcLogEntry_LOGGER_UNKNOWN
}
type isGrpcLogEntry_Payload interface {
isGrpcLogEntry_Payload()
}
type GrpcLogEntry_ClientHeader struct {
ClientHeader *ClientHeader `protobuf:"bytes,6,opt,name=client_header,json=clientHeader,proto3,oneof"`
}
type GrpcLogEntry_ServerHeader struct {
ServerHeader *ServerHeader `protobuf:"bytes,7,opt,name=server_header,json=serverHeader,proto3,oneof"`
}
type GrpcLogEntry_Message struct {
Message *Message `protobuf:"bytes,8,opt,name=message,proto3,oneof"`
}
type GrpcLogEntry_Trailer struct {
Trailer *Trailer `protobuf:"bytes,9,opt,name=trailer,proto3,oneof"`
}
func (*GrpcLogEntry_ClientHeader) isGrpcLogEntry_Payload() {}
func (*GrpcLogEntry_ServerHeader) isGrpcLogEntry_Payload() {}
func (*GrpcLogEntry_Message) isGrpcLogEntry_Payload() {}
func (*GrpcLogEntry_Trailer) isGrpcLogEntry_Payload() {}
func (m *GrpcLogEntry) GetPayload() isGrpcLogEntry_Payload {
if m != nil {
return m.Payload
}
return nil
}
func (m *GrpcLogEntry) GetClientHeader() *ClientHeader {
if x, ok := m.GetPayload().(*GrpcLogEntry_ClientHeader); ok {
return x.ClientHeader
}
return nil
}
func (m *GrpcLogEntry) GetServerHeader() *ServerHeader {
if x, ok := m.GetPayload().(*GrpcLogEntry_ServerHeader); ok {
return x.ServerHeader
}
return nil
}
func (m *GrpcLogEntry) GetMessage() *Message {
if x, ok := m.GetPayload().(*GrpcLogEntry_Message); ok {
return x.Message
}
return nil
}
func (m *GrpcLogEntry) GetTrailer() *Trailer {
if x, ok := m.GetPayload().(*GrpcLogEntry_Trailer); ok {
return x.Trailer
}
return nil
}
func (m *GrpcLogEntry) GetPayloadTruncated() bool {
if m != nil {
return m.PayloadTruncated
}
return false
}
func (m *GrpcLogEntry) GetPeer() *Address {
if m != nil {
return m.Peer
}
return nil
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*GrpcLogEntry) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _GrpcLogEntry_OneofMarshaler, _GrpcLogEntry_OneofUnmarshaler, _GrpcLogEntry_OneofSizer, []interface{}{
(*GrpcLogEntry_ClientHeader)(nil),
(*GrpcLogEntry_ServerHeader)(nil),
(*GrpcLogEntry_Message)(nil),
(*GrpcLogEntry_Trailer)(nil),
}
}
func _GrpcLogEntry_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
m := msg.(*GrpcLogEntry)
// payload
switch x := m.Payload.(type) {
case *GrpcLogEntry_ClientHeader:
b.EncodeVarint(6<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ClientHeader); err != nil {
return err
}
case *GrpcLogEntry_ServerHeader:
b.EncodeVarint(7<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.ServerHeader); err != nil {
return err
}
case *GrpcLogEntry_Message:
b.EncodeVarint(8<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.Message); err != nil {
return err
}
case *GrpcLogEntry_Trailer:
b.EncodeVarint(9<<3 | proto.WireBytes)
if err := b.EncodeMessage(x.Trailer); err != nil {
return err
}
case nil:
default:
return fmt.Errorf("GrpcLogEntry.Payload has unexpected type %T", x)
}
return nil
}
func _GrpcLogEntry_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) {
m := msg.(*GrpcLogEntry)
switch tag {
case 6: // payload.client_header
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ClientHeader)
err := b.DecodeMessage(msg)
m.Payload = &GrpcLogEntry_ClientHeader{msg}
return true, err
case 7: // payload.server_header
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(ServerHeader)
err := b.DecodeMessage(msg)
m.Payload = &GrpcLogEntry_ServerHeader{msg}
return true, err
case 8: // payload.message
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(Message)
err := b.DecodeMessage(msg)
m.Payload = &GrpcLogEntry_Message{msg}
return true, err
case 9: // payload.trailer
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
msg := new(Trailer)
err := b.DecodeMessage(msg)
m.Payload = &GrpcLogEntry_Trailer{msg}
return true, err
default:
return false, nil
}
}
func _GrpcLogEntry_OneofSizer(msg proto.Message) (n int) {
m := msg.(*GrpcLogEntry)
// payload
switch x := m.Payload.(type) {
case *GrpcLogEntry_ClientHeader:
s := proto.Size(x.ClientHeader)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *GrpcLogEntry_ServerHeader:
s := proto.Size(x.ServerHeader)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *GrpcLogEntry_Message:
s := proto.Size(x.Message)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case *GrpcLogEntry_Trailer:
s := proto.Size(x.Trailer)
n += 1 // tag and wire
n += proto.SizeVarint(uint64(s))
n += s
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
}
return n
}
type ClientHeader struct {
// This contains only the metadata from the application.
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
// The name of the RPC method, which looks something like:
// /<service>/<method>
// Note the leading "/" character.
MethodName string `protobuf:"bytes,2,opt,name=method_name,json=methodName,proto3" json:"method_name,omitempty"`
// A single process may be used to run multiple virtual
// servers with different identities.
// The authority is the name of such a server identitiy.
// It is typically a portion of the URI in the form of
// <host> or <host>:<port> .
Authority string `protobuf:"bytes,3,opt,name=authority,proto3" json:"authority,omitempty"`
// the RPC timeout
Timeout *duration.Duration `protobuf:"bytes,4,opt,name=timeout,proto3" json:"timeout,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientHeader) Reset() { *m = ClientHeader{} }
func (m *ClientHeader) String() string { return proto.CompactTextString(m) }
func (*ClientHeader) ProtoMessage() {}
func (*ClientHeader) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{1}
}
func (m *ClientHeader) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientHeader.Unmarshal(m, b)
}
func (m *ClientHeader) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientHeader.Marshal(b, m, deterministic)
}
func (dst *ClientHeader) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientHeader.Merge(dst, src)
}
func (m *ClientHeader) XXX_Size() int {
return xxx_messageInfo_ClientHeader.Size(m)
}
func (m *ClientHeader) XXX_DiscardUnknown() {
xxx_messageInfo_ClientHeader.DiscardUnknown(m)
}
var xxx_messageInfo_ClientHeader proto.InternalMessageInfo
func (m *ClientHeader) GetMetadata() *Metadata {
if m != nil {
return m.Metadata
}
return nil
}
func (m *ClientHeader) GetMethodName() string {
if m != nil {
return m.MethodName
}
return ""
}
func (m *ClientHeader) GetAuthority() string {
if m != nil {
return m.Authority
}
return ""
}
func (m *ClientHeader) GetTimeout() *duration.Duration {
if m != nil {
return m.Timeout
}
return nil
}
type ServerHeader struct {
// This contains only the metadata from the application.
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerHeader) Reset() { *m = ServerHeader{} }
func (m *ServerHeader) String() string { return proto.CompactTextString(m) }
func (*ServerHeader) ProtoMessage() {}
func (*ServerHeader) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{2}
}
func (m *ServerHeader) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerHeader.Unmarshal(m, b)
}
func (m *ServerHeader) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerHeader.Marshal(b, m, deterministic)
}
func (dst *ServerHeader) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerHeader.Merge(dst, src)
}
func (m *ServerHeader) XXX_Size() int {
return xxx_messageInfo_ServerHeader.Size(m)
}
func (m *ServerHeader) XXX_DiscardUnknown() {
xxx_messageInfo_ServerHeader.DiscardUnknown(m)
}
var xxx_messageInfo_ServerHeader proto.InternalMessageInfo
func (m *ServerHeader) GetMetadata() *Metadata {
if m != nil {
return m.Metadata
}
return nil
}
type Trailer struct {
// This contains only the metadata from the application.
Metadata *Metadata `protobuf:"bytes,1,opt,name=metadata,proto3" json:"metadata,omitempty"`
// The gRPC status code.
StatusCode uint32 `protobuf:"varint,2,opt,name=status_code,json=statusCode,proto3" json:"status_code,omitempty"`
// An original status message before any transport specific
// encoding.
StatusMessage string `protobuf:"bytes,3,opt,name=status_message,json=statusMessage,proto3" json:"status_message,omitempty"`
// The value of the 'grpc-status-details-bin' metadata key. If
// present, this is always an encoded 'google.rpc.Status' message.
StatusDetails []byte `protobuf:"bytes,4,opt,name=status_details,json=statusDetails,proto3" json:"status_details,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Trailer) Reset() { *m = Trailer{} }
func (m *Trailer) String() string { return proto.CompactTextString(m) }
func (*Trailer) ProtoMessage() {}
func (*Trailer) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{3}
}
func (m *Trailer) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Trailer.Unmarshal(m, b)
}
func (m *Trailer) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Trailer.Marshal(b, m, deterministic)
}
func (dst *Trailer) XXX_Merge(src proto.Message) {
xxx_messageInfo_Trailer.Merge(dst, src)
}
func (m *Trailer) XXX_Size() int {
return xxx_messageInfo_Trailer.Size(m)
}
func (m *Trailer) XXX_DiscardUnknown() {
xxx_messageInfo_Trailer.DiscardUnknown(m)
}
var xxx_messageInfo_Trailer proto.InternalMessageInfo
func (m *Trailer) GetMetadata() *Metadata {
if m != nil {
return m.Metadata
}
return nil
}
func (m *Trailer) GetStatusCode() uint32 {
if m != nil {
return m.StatusCode
}
return 0
}
func (m *Trailer) GetStatusMessage() string {
if m != nil {
return m.StatusMessage
}
return ""
}
func (m *Trailer) GetStatusDetails() []byte {
if m != nil {
return m.StatusDetails
}
return nil
}
// Message payload, used by CLIENT_MESSAGE and SERVER_MESSAGE
type Message struct {
// Length of the message. It may not be the same as the length of the
// data field, as the logging payload can be truncated or omitted.
Length uint32 `protobuf:"varint,1,opt,name=length,proto3" json:"length,omitempty"`
// May be truncated or omitted.
Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Message) Reset() { *m = Message{} }
func (m *Message) String() string { return proto.CompactTextString(m) }
func (*Message) ProtoMessage() {}
func (*Message) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{4}
}
func (m *Message) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Message.Unmarshal(m, b)
}
func (m *Message) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Message.Marshal(b, m, deterministic)
}
func (dst *Message) XXX_Merge(src proto.Message) {
xxx_messageInfo_Message.Merge(dst, src)
}
func (m *Message) XXX_Size() int {
return xxx_messageInfo_Message.Size(m)
}
func (m *Message) XXX_DiscardUnknown() {
xxx_messageInfo_Message.DiscardUnknown(m)
}
var xxx_messageInfo_Message proto.InternalMessageInfo
func (m *Message) GetLength() uint32 {
if m != nil {
return m.Length
}
return 0
}
func (m *Message) GetData() []byte {
if m != nil {
return m.Data
}
return nil
}
// A list of metadata pairs, used in the payload of client header,
// server header, and server trailer.
// Implementations may omit some entries to honor the header limits
// of GRPC_BINARY_LOG_CONFIG.
//
// Header keys added by gRPC are omitted. To be more specific,
// implementations will not log the following entries, and this is
// not to be treated as a truncation:
// - entries handled by grpc that are not user visible, such as those
// that begin with 'grpc-' (with exception of grpc-trace-bin)
// or keys like 'lb-token'
// - transport specific entries, including but not limited to:
// ':path', ':authority', 'content-encoding', 'user-agent', 'te', etc
// - entries added for call credentials
//
// Implementations must always log grpc-trace-bin if it is present.
// Practically speaking it will only be visible on server side because
// grpc-trace-bin is managed by low level client side mechanisms
// inaccessible from the application level. On server side, the
// header is just a normal metadata key.
// The pair will not count towards the size limit.
type Metadata struct {
Entry []*MetadataEntry `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Metadata) Reset() { *m = Metadata{} }
func (m *Metadata) String() string { return proto.CompactTextString(m) }
func (*Metadata) ProtoMessage() {}
func (*Metadata) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{5}
}
func (m *Metadata) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Metadata.Unmarshal(m, b)
}
func (m *Metadata) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Metadata.Marshal(b, m, deterministic)
}
func (dst *Metadata) XXX_Merge(src proto.Message) {
xxx_messageInfo_Metadata.Merge(dst, src)
}
func (m *Metadata) XXX_Size() int {
return xxx_messageInfo_Metadata.Size(m)
}
func (m *Metadata) XXX_DiscardUnknown() {
xxx_messageInfo_Metadata.DiscardUnknown(m)
}
var xxx_messageInfo_Metadata proto.InternalMessageInfo
func (m *Metadata) GetEntry() []*MetadataEntry {
if m != nil {
return m.Entry
}
return nil
}
// A metadata key value pair
type MetadataEntry struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *MetadataEntry) Reset() { *m = MetadataEntry{} }
func (m *MetadataEntry) String() string { return proto.CompactTextString(m) }
func (*MetadataEntry) ProtoMessage() {}
func (*MetadataEntry) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{6}
}
func (m *MetadataEntry) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_MetadataEntry.Unmarshal(m, b)
}
func (m *MetadataEntry) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_MetadataEntry.Marshal(b, m, deterministic)
}
func (dst *MetadataEntry) XXX_Merge(src proto.Message) {
xxx_messageInfo_MetadataEntry.Merge(dst, src)
}
func (m *MetadataEntry) XXX_Size() int {
return xxx_messageInfo_MetadataEntry.Size(m)
}
func (m *MetadataEntry) XXX_DiscardUnknown() {
xxx_messageInfo_MetadataEntry.DiscardUnknown(m)
}
var xxx_messageInfo_MetadataEntry proto.InternalMessageInfo
func (m *MetadataEntry) GetKey() string {
if m != nil {
return m.Key
}
return ""
}
func (m *MetadataEntry) GetValue() []byte {
if m != nil {
return m.Value
}
return nil
}
// Address information
type Address struct {
Type Address_Type `protobuf:"varint,1,opt,name=type,proto3,enum=grpc.binarylog.v1.Address_Type" json:"type,omitempty"`
Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
// only for TYPE_IPV4 and TYPE_IPV6
IpPort uint32 `protobuf:"varint,3,opt,name=ip_port,json=ipPort,proto3" json:"ip_port,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Address) Reset() { *m = Address{} }
func (m *Address) String() string { return proto.CompactTextString(m) }
func (*Address) ProtoMessage() {}
func (*Address) Descriptor() ([]byte, []int) {
return fileDescriptor_binarylog_264c8c9c551ce911, []int{7}
}
func (m *Address) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Address.Unmarshal(m, b)
}
func (m *Address) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Address.Marshal(b, m, deterministic)
}
func (dst *Address) XXX_Merge(src proto.Message) {
xxx_messageInfo_Address.Merge(dst, src)
}
func (m *Address) XXX_Size() int {
return xxx_messageInfo_Address.Size(m)
}
func (m *Address) XXX_DiscardUnknown() {
xxx_messageInfo_Address.DiscardUnknown(m)
}
var xxx_messageInfo_Address proto.InternalMessageInfo
func (m *Address) GetType() Address_Type {
if m != nil {
return m.Type
}
return Address_TYPE_UNKNOWN
}
func (m *Address) GetAddress() string {
if m != nil {
return m.Address
}
return ""
}
func (m *Address) GetIpPort() uint32 {
if m != nil {
return m.IpPort
}
return 0
}
func init() {
proto.RegisterType((*GrpcLogEntry)(nil), "grpc.binarylog.v1.GrpcLogEntry")
proto.RegisterType((*ClientHeader)(nil), "grpc.binarylog.v1.ClientHeader")
proto.RegisterType((*ServerHeader)(nil), "grpc.binarylog.v1.ServerHeader")
proto.RegisterType((*Trailer)(nil), "grpc.binarylog.v1.Trailer")
proto.RegisterType((*Message)(nil), "grpc.binarylog.v1.Message")
proto.RegisterType((*Metadata)(nil), "grpc.binarylog.v1.Metadata")
proto.RegisterType((*MetadataEntry)(nil), "grpc.binarylog.v1.MetadataEntry")
proto.RegisterType((*Address)(nil), "grpc.binarylog.v1.Address")
proto.RegisterEnum("grpc.binarylog.v1.GrpcLogEntry_EventType", GrpcLogEntry_EventType_name, GrpcLogEntry_EventType_value)
proto.RegisterEnum("grpc.binarylog.v1.GrpcLogEntry_Logger", GrpcLogEntry_Logger_name, GrpcLogEntry_Logger_value)
proto.RegisterEnum("grpc.binarylog.v1.Address_Type", Address_Type_name, Address_Type_value)
}
func init() {
proto.RegisterFile("grpc/binarylog/grpc_binarylog_v1/binarylog.proto", fileDescriptor_binarylog_264c8c9c551ce911)
}
var fileDescriptor_binarylog_264c8c9c551ce911 = []byte{
// 900 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x55, 0x51, 0x6f, 0xe3, 0x44,
0x10, 0x3e, 0x37, 0x69, 0xdc, 0x4c, 0x92, 0xca, 0x5d, 0x95, 0x3b, 0x5f, 0x29, 0x34, 0xb2, 0x04,
0x0a, 0x42, 0x72, 0xb9, 0x94, 0xeb, 0xf1, 0x02, 0x52, 0x92, 0xfa, 0xd2, 0x88, 0x5c, 0x1a, 0x6d,
0x72, 0x3d, 0x40, 0x48, 0xd6, 0x36, 0x5e, 0x1c, 0x0b, 0xc7, 0x6b, 0xd6, 0x9b, 0xa0, 0xfc, 0x2c,
0xde, 0x90, 0xee, 0x77, 0xf1, 0x8e, 0xbc, 0x6b, 0x27, 0xa6, 0x69, 0x0f, 0x09, 0xde, 0x3c, 0xdf,
0x7c, 0xf3, 0xcd, 0xee, 0x78, 0x66, 0x16, 0xbe, 0xf2, 0x79, 0x3c, 0x3b, 0xbf, 0x0b, 0x22, 0xc2,
0xd7, 0x21, 0xf3, 0xcf, 0x53, 0xd3, 0xdd, 0x98, 0xee, 0xea, 0xc5, 0xd6, 0x67, 0xc7, 0x9c, 0x09,
0x86, 0x8e, 0x52, 0x8a, 0xbd, 0x45, 0x57, 0x2f, 0x4e, 0x3e, 0xf5, 0x19, 0xf3, 0x43, 0x7a, 0x2e,
0x09, 0x77, 0xcb, 0x5f, 0xce, 0xbd, 0x25, 0x27, 0x22, 0x60, 0x91, 0x0a, 0x39, 0x39, 0xbb, 0xef,
0x17, 0xc1, 0x82, 0x26, 0x82, 0x2c, 0x62, 0x45, 0xb0, 0xde, 0xeb, 0x50, 0xef, 0xf3, 0x78, 0x36,
0x64, 0xbe, 0x13, 0x09, 0xbe, 0x46, 0xdf, 0x40, 0x75, 0xc3, 0x31, 0xb5, 0xa6, 0xd6, 0xaa, 0xb5,
0x4f, 0x6c, 0xa5, 0x62, 0xe7, 0x2a, 0xf6, 0x34, 0x67, 0xe0, 0x2d, 0x19, 0x3d, 0x03, 0x7d, 0x46,
0xc2, 0xd0, 0x0d, 0x3c, 0x73, 0xaf, 0xa9, 0xb5, 0xca, 0xb8, 0x92, 0x9a, 0x03, 0x0f, 0xbd, 0x84,
0x67, 0x09, 0xfd, 0x6d, 0x49, 0xa3, 0x19, 0x75, 0x03, 0xcf, 0xfd, 0x3d, 0x10, 0xf3, 0x20, 0x72,
0x53, 0xa7, 0x59, 0x92, 0xc4, 0xe3, 0xdc, 0x3d, 0xf0, 0xde, 0x49, 0x67, 0x8f, 0x84, 0x21, 0xfa,
0x16, 0xca, 0x62, 0x1d, 0x53, 0xb3, 0xdc, 0xd4, 0x5a, 0x87, 0xed, 0x2f, 0xec, 0x9d, 0xdb, 0xdb,
0xc5, 0x83, 0xdb, 0xce, 0x8a, 0x46, 0x62, 0xba, 0x8e, 0x29, 0x96, 0x61, 0xe8, 0x3b, 0xa8, 0x84,
0xcc, 0xf7, 0x29, 0x37, 0xf7, 0xa5, 0xc0, 0xe7, 0xff, 0x26, 0x30, 0x94, 0x6c, 0x9c, 0x45, 0xa1,
0xd7, 0xd0, 0x98, 0x85, 0x01, 0x8d, 0x84, 0x3b, 0xa7, 0xc4, 0xa3, 0xdc, 0xac, 0xc8, 0x62, 0x9c,
0x3d, 0x20, 0xd3, 0x93, 0xbc, 0x6b, 0x49, 0xbb, 0x7e, 0x82, 0xeb, 0xb3, 0x82, 0x9d, 0xea, 0x24,
0x94, 0xaf, 0x28, 0xcf, 0x75, 0xf4, 0x47, 0x75, 0x26, 0x92, 0xb7, 0xd5, 0x49, 0x0a, 0x36, 0xba,
0x04, 0x7d, 0x41, 0x93, 0x84, 0xf8, 0xd4, 0x3c, 0xc8, 0x7f, 0xcb, 0x8e, 0xc2, 0x1b, 0xc5, 0xb8,
0x7e, 0x82, 0x73, 0x72, 0x1a, 0x27, 0x38, 0x09, 0x42, 0xca, 0xcd, 0xea, 0xa3, 0x71, 0x53, 0xc5,
0x48, 0xe3, 0x32, 0x32, 0xfa, 0x12, 0x8e, 0x62, 0xb2, 0x0e, 0x19, 0xf1, 0x5c, 0xc1, 0x97, 0xd1,
0x8c, 0x08, 0xea, 0x99, 0xd0, 0xd4, 0x5a, 0x07, 0xd8, 0xc8, 0x1c, 0xd3, 0x1c, 0x47, 0x36, 0x94,
0x63, 0x4a, 0xb9, 0x59, 0x7b, 0x34, 0x43, 0xc7, 0xf3, 0x38, 0x4d, 0x12, 0x2c, 0x79, 0xd6, 0x5f,
0x1a, 0x54, 0x37, 0x3f, 0x0c, 0x3d, 0x05, 0xe4, 0xdc, 0x3a, 0xa3, 0xa9, 0x3b, 0xfd, 0x71, 0xec,
0xb8, 0x6f, 0x47, 0xdf, 0x8f, 0x6e, 0xde, 0x8d, 0x8c, 0x27, 0xe8, 0x14, 0xcc, 0x02, 0xde, 0x1b,
0x0e, 0xd2, 0xef, 0x6b, 0xa7, 0x73, 0xe5, 0x60, 0x43, 0xbb, 0xe7, 0x9d, 0x38, 0xf8, 0xd6, 0xc1,
0xb9, 0x77, 0x0f, 0x7d, 0x02, 0xcf, 0x77, 0x63, 0xdf, 0x38, 0x93, 0x49, 0xa7, 0xef, 0x18, 0xa5,
0x7b, 0xee, 0x2c, 0x38, 0x77, 0x97, 0x51, 0x13, 0x4e, 0x1f, 0xc8, 0xdc, 0x19, 0xbe, 0x76, 0x7b,
0xc3, 0x9b, 0x89, 0x63, 0xec, 0x3f, 0x2c, 0x30, 0xc5, 0x9d, 0xc1, 0xd0, 0xc1, 0x46, 0x05, 0x7d,
0x04, 0x47, 0x45, 0x81, 0xce, 0xa8, 0xe7, 0x0c, 0x0d, 0xdd, 0xea, 0x42, 0x45, 0xb5, 0x19, 0x42,
0x70, 0x38, 0xbc, 0xe9, 0xf7, 0x1d, 0x5c, 0xb8, 0xef, 0x11, 0x34, 0x32, 0x4c, 0x65, 0x34, 0xb4,
0x02, 0xa4, 0x52, 0x18, 0x7b, 0xdd, 0x2a, 0xe8, 0x59, 0xfd, 0xad, 0xf7, 0x1a, 0xd4, 0x8b, 0xcd,
0x87, 0x5e, 0xc1, 0xc1, 0x82, 0x0a, 0xe2, 0x11, 0x41, 0xb2, 0xe1, 0xfd, 0xf8, 0xc1, 0x2e, 0x51,
0x14, 0xbc, 0x21, 0xa3, 0x33, 0xa8, 0x2d, 0xa8, 0x98, 0x33, 0xcf, 0x8d, 0xc8, 0x82, 0xca, 0x01,
0xae, 0x62, 0x50, 0xd0, 0x88, 0x2c, 0x28, 0x3a, 0x85, 0x2a, 0x59, 0x8a, 0x39, 0xe3, 0x81, 0x58,
0xcb, 0xb1, 0xad, 0xe2, 0x2d, 0x80, 0x2e, 0x40, 0x4f, 0x17, 0x01, 0x5b, 0x0a, 0x39, 0xae, 0xb5,
0xf6, 0xf3, 0x9d, 0x9d, 0x71, 0x95, 0x6d, 0x26, 0x9c, 0x33, 0xad, 0x3e, 0xd4, 0x8b, 0x1d, 0xff,
0x9f, 0x0f, 0x6f, 0xfd, 0xa1, 0x81, 0x9e, 0x75, 0xf0, 0xff, 0xaa, 0x40, 0x22, 0x88, 0x58, 0x26,
0xee, 0x8c, 0x79, 0xaa, 0x02, 0x0d, 0x0c, 0x0a, 0xea, 0x31, 0x8f, 0xa2, 0xcf, 0xe0, 0x30, 0x23,
0xe4, 0x73, 0xa8, 0xca, 0xd0, 0x50, 0x68, 0x36, 0x7a, 0x05, 0x9a, 0x47, 0x05, 0x09, 0xc2, 0x44,
0x56, 0xa4, 0x9e, 0xd3, 0xae, 0x14, 0x68, 0xbd, 0x04, 0x3d, 0x8f, 0x78, 0x0a, 0x95, 0x90, 0x46,
0xbe, 0x98, 0xcb, 0x03, 0x37, 0x70, 0x66, 0x21, 0x04, 0x65, 0x79, 0x8d, 0x3d, 0x19, 0x2f, 0xbf,
0xad, 0x2e, 0x1c, 0xe4, 0x67, 0x47, 0x97, 0xb0, 0x4f, 0xd3, 0xcd, 0x65, 0x6a, 0xcd, 0x52, 0xab,
0xd6, 0x6e, 0x7e, 0xe0, 0x9e, 0x72, 0xc3, 0x61, 0x45, 0xb7, 0x5e, 0x41, 0xe3, 0x1f, 0x38, 0x32,
0xa0, 0xf4, 0x2b, 0x5d, 0xcb, 0xec, 0x55, 0x9c, 0x7e, 0xa2, 0x63, 0xd8, 0x5f, 0x91, 0x70, 0x49,
0xb3, 0xdc, 0xca, 0xb0, 0xfe, 0xd4, 0x40, 0xcf, 0xe6, 0x18, 0x5d, 0x64, 0xdb, 0x59, 0x93, 0xcb,
0xf5, 0xec, 0xf1, 0x89, 0xb7, 0x0b, 0x3b, 0xd9, 0x04, 0x9d, 0x28, 0x34, 0xeb, 0xb0, 0xdc, 0x4c,
0x1f, 0x8f, 0x20, 0x76, 0x63, 0xc6, 0x85, 0xac, 0x6a, 0x03, 0x57, 0x82, 0x78, 0xcc, 0xb8, 0xb0,
0x1c, 0x28, 0xcb, 0x1d, 0x61, 0x40, 0xfd, 0xde, 0x76, 0x68, 0x40, 0x55, 0x22, 0x83, 0xf1, 0xed,
0xd7, 0x86, 0x56, 0x34, 0x2f, 0x8d, 0xbd, 0x8d, 0xf9, 0x76, 0x34, 0xf8, 0xc1, 0x28, 0x75, 0x7f,
0x86, 0xe3, 0x80, 0xed, 0x1e, 0xb2, 0x7b, 0xd8, 0x95, 0xd6, 0x90, 0xf9, 0xe3, 0xb4, 0x51, 0xc7,
0xda, 0x4f, 0xed, 0xac, 0x71, 0x7d, 0x16, 0x92, 0xc8, 0xb7, 0x19, 0x57, 0x4f, 0xf3, 0x87, 0x5e,
0xea, 0xbb, 0x8a, 0xec, 0xf2, 0x8b, 0xbf, 0x03, 0x00, 0x00, 0xff, 0xff, 0xe7, 0xf6, 0x4b, 0x50,
0xd4, 0x07, 0x00, 0x00,
}

View File

@@ -27,12 +27,31 @@ import (
//
// All errors returned by Invoke are compatible with the status package.
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
// allow interceptor to see all applicable call options, which means those
// configured as defaults from dial option as well as per-call options
opts = combine(cc.dopts.callOptions, opts)
if cc.dopts.unaryInt != nil {
return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
}
return invoke(ctx, method, args, reply, cc, opts...)
}
func combine(o1 []CallOption, o2 []CallOption) []CallOption {
// we don't use append because o1 could have extra capacity whose
// elements would be overwritten, which could cause inadvertent
// sharing (and race connditions) between concurrent calls
if len(o1) == 0 {
return o2
} else if len(o2) == 0 {
return o1
}
ret := make([]CallOption, len(o1)+len(o2))
copy(ret, o1)
copy(ret[len(o1):], o2)
return ret
}
// Invoke sends the RPC request on the wire and returns after response is
// received. This is typically called by generated code.
//
@@ -44,31 +63,12 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
var unaryStreamDesc = &StreamDesc{ServerStreams: false, ClientStreams: false}
func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
// TODO: implement retries in clientStream and make this simply
// newClientStream, SendMsg, RecvMsg.
firstAttempt := true
for {
csInt, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
if err != nil {
return err
}
cs := csInt.(*clientStream)
if err := cs.SendMsg(req); err != nil {
if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
return err
}
if err := cs.RecvMsg(reply); err != nil {
if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt {
// TODO: Add a field to header for grpc-transparent-retry-attempts
firstAttempt = false
continue
}
return err
}
return nil
cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
if err != nil {
return err
}
if err := cs.SendMsg(req); err != nil {
return err
}
return cs.RecvMsg(reply)
}

View File

@@ -31,9 +31,9 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/leakcheck"
"google.golang.org/grpc/transport"
)
var (
@@ -105,12 +105,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
}
}
// send a response back to end the stream.
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
data, err := encode(testCodec{}, &expectedResponse)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
}
h.t.Write(s, hdr, data, &transport.Options{})
hdr, payload := msgHeader(data, nil)
h.t.Write(s, hdr, payload, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
@@ -217,7 +218,7 @@ func TestInvoke(t *testing.T) {
defer leakcheck.Check(t)
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
cc.Close()
@@ -229,7 +230,7 @@ func TestInvokeLargeErr(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
@@ -246,7 +247,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply)
if _, ok := status.FromError(err); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
}
@@ -266,7 +267,7 @@ func TestInvokeCancel(t *testing.T) {
for i := 0; i < 100; i++ {
ctx, cancel := context.WithCancel(context.Background())
cancel()
Invoke(ctx, "/foo/bar", &req, &reply, cc)
cc.Invoke(ctx, "/foo/bar", &req, &reply)
}
if canceled != 0 {
t.Fatalf("received %d of 100 canceled requests", canceled)
@@ -285,7 +286,7 @@ func TestInvokeCancelClosedNonFailFast(t *testing.T) {
req := "hello"
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := Invoke(ctx, "/foo/bar", &req, &reply, cc, FailFast(false)); err == nil {
if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, FailFast(false)); err == nil {
t.Fatalf("canceled invoke on closed connection should fail")
}
server.stop()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,105 @@
// +build !appengine,go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"github.com/golang/protobuf/ptypes"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption {
var opts []*channelzpb.SocketOption
if skopts.Linger != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionLinger{
Active: skopts.Linger.Onoff != 0,
Duration: convertToPtypesDuration(int64(skopts.Linger.Linger), 0),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_LINGER",
Additional: additional,
})
}
}
if skopts.RecvTimeout != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{
Duration: convertToPtypesDuration(int64(skopts.RecvTimeout.Sec), int64(skopts.RecvTimeout.Usec)),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_RCVTIMEO",
Additional: additional,
})
}
}
if skopts.SendTimeout != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTimeout{
Duration: convertToPtypesDuration(int64(skopts.SendTimeout.Sec), int64(skopts.SendTimeout.Usec)),
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "SO_SNDTIMEO",
Additional: additional,
})
}
}
if skopts.TCPInfo != nil {
additional, err := ptypes.MarshalAny(&channelzpb.SocketOptionTcpInfo{
TcpiState: uint32(skopts.TCPInfo.State),
TcpiCaState: uint32(skopts.TCPInfo.Ca_state),
TcpiRetransmits: uint32(skopts.TCPInfo.Retransmits),
TcpiProbes: uint32(skopts.TCPInfo.Probes),
TcpiBackoff: uint32(skopts.TCPInfo.Backoff),
TcpiOptions: uint32(skopts.TCPInfo.Options),
// https://golang.org/pkg/syscall/#TCPInfo
// TCPInfo struct does not contain info about TcpiSndWscale and TcpiRcvWscale.
TcpiRto: skopts.TCPInfo.Rto,
TcpiAto: skopts.TCPInfo.Ato,
TcpiSndMss: skopts.TCPInfo.Snd_mss,
TcpiRcvMss: skopts.TCPInfo.Rcv_mss,
TcpiUnacked: skopts.TCPInfo.Unacked,
TcpiSacked: skopts.TCPInfo.Sacked,
TcpiLost: skopts.TCPInfo.Lost,
TcpiRetrans: skopts.TCPInfo.Retrans,
TcpiFackets: skopts.TCPInfo.Fackets,
TcpiLastDataSent: skopts.TCPInfo.Last_data_sent,
TcpiLastAckSent: skopts.TCPInfo.Last_ack_sent,
TcpiLastDataRecv: skopts.TCPInfo.Last_data_recv,
TcpiLastAckRecv: skopts.TCPInfo.Last_ack_recv,
TcpiPmtu: skopts.TCPInfo.Pmtu,
TcpiRcvSsthresh: skopts.TCPInfo.Rcv_ssthresh,
TcpiRtt: skopts.TCPInfo.Rtt,
TcpiRttvar: skopts.TCPInfo.Rttvar,
TcpiSndSsthresh: skopts.TCPInfo.Snd_ssthresh,
TcpiSndCwnd: skopts.TCPInfo.Snd_cwnd,
TcpiAdvmss: skopts.TCPInfo.Advmss,
TcpiReordering: skopts.TCPInfo.Reordering,
})
if err == nil {
opts = append(opts, &channelzpb.SocketOption{
Name: "TCP_INFO",
Additional: additional,
})
}
}
return opts
}

View File

@@ -0,0 +1,30 @@
// +build !linux appengine !go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func sockoptToProto(skopts *channelz.SocketOptionData) []*channelzpb.SocketOption {
return nil
}

33
vendor/google.golang.org/grpc/channelz/service/regenerate.sh generated vendored Executable file
View File

@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/channelz/v1
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/channelz/v1/channelz.proto > grpc/channelz/v1/channelz.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/channelz/v1/*.proto
popd
rm -f ../grpc_channelz_v1/*.pb.go
cp "$TMP"/grpc/channelz/v1/*.pb.go ../grpc_channelz_v1/

View File

@@ -0,0 +1,335 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package service provides an implementation for channelz service server.
package service
import (
"net"
"time"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
wrpb "github.com/golang/protobuf/ptypes/wrappers"
"golang.org/x/net/context"
"google.golang.org/grpc"
channelzgrpc "google.golang.org/grpc/channelz/grpc_channelz_v1"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
)
func init() {
channelz.TurnOn()
}
func convertToPtypesDuration(sec int64, usec int64) *durpb.Duration {
return ptypes.DurationProto(time.Duration(sec*1e9 + usec*1e3))
}
// RegisterChannelzServiceToServer registers the channelz service to the given server.
func RegisterChannelzServiceToServer(s *grpc.Server) {
channelzgrpc.RegisterChannelzServer(s, newCZServer())
}
func newCZServer() channelzgrpc.ChannelzServer {
return &serverImpl{}
}
type serverImpl struct{}
func connectivityStateToProto(s connectivity.State) *channelzpb.ChannelConnectivityState {
switch s {
case connectivity.Idle:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_IDLE}
case connectivity.Connecting:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_CONNECTING}
case connectivity.Ready:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_READY}
case connectivity.TransientFailure:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_TRANSIENT_FAILURE}
case connectivity.Shutdown:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_SHUTDOWN}
default:
return &channelzpb.ChannelConnectivityState{State: channelzpb.ChannelConnectivityState_UNKNOWN}
}
}
func channelTraceToProto(ct *channelz.ChannelTrace) *channelzpb.ChannelTrace {
pbt := &channelzpb.ChannelTrace{}
pbt.NumEventsLogged = ct.EventNum
if ts, err := ptypes.TimestampProto(ct.CreationTime); err == nil {
pbt.CreationTimestamp = ts
}
var events []*channelzpb.ChannelTraceEvent
for _, e := range ct.Events {
cte := &channelzpb.ChannelTraceEvent{
Description: e.Desc,
Severity: channelzpb.ChannelTraceEvent_Severity(e.Severity),
}
if ts, err := ptypes.TimestampProto(e.Timestamp); err == nil {
cte.Timestamp = ts
}
if e.RefID != 0 {
switch e.RefType {
case channelz.RefChannel:
cte.ChildRef = &channelzpb.ChannelTraceEvent_ChannelRef{ChannelRef: &channelzpb.ChannelRef{ChannelId: e.RefID, Name: e.RefName}}
case channelz.RefSubChannel:
cte.ChildRef = &channelzpb.ChannelTraceEvent_SubchannelRef{SubchannelRef: &channelzpb.SubchannelRef{SubchannelId: e.RefID, Name: e.RefName}}
}
}
events = append(events, cte)
}
pbt.Events = events
return pbt
}
func channelMetricToProto(cm *channelz.ChannelMetric) *channelzpb.Channel {
c := &channelzpb.Channel{}
c.Ref = &channelzpb.ChannelRef{ChannelId: cm.ID, Name: cm.RefName}
c.Data = &channelzpb.ChannelData{
State: connectivityStateToProto(cm.ChannelData.State),
Target: cm.ChannelData.Target,
CallsStarted: cm.ChannelData.CallsStarted,
CallsSucceeded: cm.ChannelData.CallsSucceeded,
CallsFailed: cm.ChannelData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(cm.ChannelData.LastCallStartedTimestamp); err == nil {
c.Data.LastCallStartedTimestamp = ts
}
nestedChans := make([]*channelzpb.ChannelRef, 0, len(cm.NestedChans))
for id, ref := range cm.NestedChans {
nestedChans = append(nestedChans, &channelzpb.ChannelRef{ChannelId: id, Name: ref})
}
c.ChannelRef = nestedChans
subChans := make([]*channelzpb.SubchannelRef, 0, len(cm.SubChans))
for id, ref := range cm.SubChans {
subChans = append(subChans, &channelzpb.SubchannelRef{SubchannelId: id, Name: ref})
}
c.SubchannelRef = subChans
sockets := make([]*channelzpb.SocketRef, 0, len(cm.Sockets))
for id, ref := range cm.Sockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
c.SocketRef = sockets
c.Data.Trace = channelTraceToProto(cm.Trace)
return c
}
func subChannelMetricToProto(cm *channelz.SubChannelMetric) *channelzpb.Subchannel {
sc := &channelzpb.Subchannel{}
sc.Ref = &channelzpb.SubchannelRef{SubchannelId: cm.ID, Name: cm.RefName}
sc.Data = &channelzpb.ChannelData{
State: connectivityStateToProto(cm.ChannelData.State),
Target: cm.ChannelData.Target,
CallsStarted: cm.ChannelData.CallsStarted,
CallsSucceeded: cm.ChannelData.CallsSucceeded,
CallsFailed: cm.ChannelData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(cm.ChannelData.LastCallStartedTimestamp); err == nil {
sc.Data.LastCallStartedTimestamp = ts
}
nestedChans := make([]*channelzpb.ChannelRef, 0, len(cm.NestedChans))
for id, ref := range cm.NestedChans {
nestedChans = append(nestedChans, &channelzpb.ChannelRef{ChannelId: id, Name: ref})
}
sc.ChannelRef = nestedChans
subChans := make([]*channelzpb.SubchannelRef, 0, len(cm.SubChans))
for id, ref := range cm.SubChans {
subChans = append(subChans, &channelzpb.SubchannelRef{SubchannelId: id, Name: ref})
}
sc.SubchannelRef = subChans
sockets := make([]*channelzpb.SocketRef, 0, len(cm.Sockets))
for id, ref := range cm.Sockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
sc.SocketRef = sockets
sc.Data.Trace = channelTraceToProto(cm.Trace)
return sc
}
func securityToProto(se credentials.ChannelzSecurityValue) *channelzpb.Security {
switch v := se.(type) {
case *credentials.TLSChannelzSecurityValue:
return &channelzpb.Security{Model: &channelzpb.Security_Tls_{Tls: &channelzpb.Security_Tls{
CipherSuite: &channelzpb.Security_Tls_StandardName{StandardName: v.StandardName},
LocalCertificate: v.LocalCertificate,
RemoteCertificate: v.RemoteCertificate,
}}}
case *credentials.OtherChannelzSecurityValue:
otherSecurity := &channelzpb.Security_OtherSecurity{
Name: v.Name,
}
if anyval, err := ptypes.MarshalAny(v.Value); err == nil {
otherSecurity.Value = anyval
}
return &channelzpb.Security{Model: &channelzpb.Security_Other{Other: otherSecurity}}
}
return nil
}
func addrToProto(a net.Addr) *channelzpb.Address {
switch a.Network() {
case "udp":
// TODO: Address_OtherAddress{}. Need proto def for Value.
case "ip":
// Note zone info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.IPAddr).IP}}}
case "ip+net":
// Note mask info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.IPNet).IP}}}
case "tcp":
// Note zone info is discarded through the conversion.
return &channelzpb.Address{Address: &channelzpb.Address_TcpipAddress{TcpipAddress: &channelzpb.Address_TcpIpAddress{IpAddress: a.(*net.TCPAddr).IP, Port: int32(a.(*net.TCPAddr).Port)}}}
case "unix", "unixgram", "unixpacket":
return &channelzpb.Address{Address: &channelzpb.Address_UdsAddress_{UdsAddress: &channelzpb.Address_UdsAddress{Filename: a.String()}}}
default:
}
return &channelzpb.Address{}
}
func socketMetricToProto(sm *channelz.SocketMetric) *channelzpb.Socket {
s := &channelzpb.Socket{}
s.Ref = &channelzpb.SocketRef{SocketId: sm.ID, Name: sm.RefName}
s.Data = &channelzpb.SocketData{
StreamsStarted: sm.SocketData.StreamsStarted,
StreamsSucceeded: sm.SocketData.StreamsSucceeded,
StreamsFailed: sm.SocketData.StreamsFailed,
MessagesSent: sm.SocketData.MessagesSent,
MessagesReceived: sm.SocketData.MessagesReceived,
KeepAlivesSent: sm.SocketData.KeepAlivesSent,
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastLocalStreamCreatedTimestamp); err == nil {
s.Data.LastLocalStreamCreatedTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastRemoteStreamCreatedTimestamp); err == nil {
s.Data.LastRemoteStreamCreatedTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastMessageSentTimestamp); err == nil {
s.Data.LastMessageSentTimestamp = ts
}
if ts, err := ptypes.TimestampProto(sm.SocketData.LastMessageReceivedTimestamp); err == nil {
s.Data.LastMessageReceivedTimestamp = ts
}
s.Data.LocalFlowControlWindow = &wrpb.Int64Value{Value: sm.SocketData.LocalFlowControlWindow}
s.Data.RemoteFlowControlWindow = &wrpb.Int64Value{Value: sm.SocketData.RemoteFlowControlWindow}
if sm.SocketData.SocketOptions != nil {
s.Data.Option = sockoptToProto(sm.SocketData.SocketOptions)
}
if sm.SocketData.Security != nil {
s.Security = securityToProto(sm.SocketData.Security)
}
if sm.SocketData.LocalAddr != nil {
s.Local = addrToProto(sm.SocketData.LocalAddr)
}
if sm.SocketData.RemoteAddr != nil {
s.Remote = addrToProto(sm.SocketData.RemoteAddr)
}
s.RemoteName = sm.SocketData.RemoteName
return s
}
func (s *serverImpl) GetTopChannels(ctx context.Context, req *channelzpb.GetTopChannelsRequest) (*channelzpb.GetTopChannelsResponse, error) {
metrics, end := channelz.GetTopChannels(req.GetStartChannelId())
resp := &channelzpb.GetTopChannelsResponse{}
for _, m := range metrics {
resp.Channel = append(resp.Channel, channelMetricToProto(m))
}
resp.End = end
return resp, nil
}
func serverMetricToProto(sm *channelz.ServerMetric) *channelzpb.Server {
s := &channelzpb.Server{}
s.Ref = &channelzpb.ServerRef{ServerId: sm.ID, Name: sm.RefName}
s.Data = &channelzpb.ServerData{
CallsStarted: sm.ServerData.CallsStarted,
CallsSucceeded: sm.ServerData.CallsSucceeded,
CallsFailed: sm.ServerData.CallsFailed,
}
if ts, err := ptypes.TimestampProto(sm.ServerData.LastCallStartedTimestamp); err == nil {
s.Data.LastCallStartedTimestamp = ts
}
sockets := make([]*channelzpb.SocketRef, 0, len(sm.ListenSockets))
for id, ref := range sm.ListenSockets {
sockets = append(sockets, &channelzpb.SocketRef{SocketId: id, Name: ref})
}
s.ListenSocket = sockets
return s
}
func (s *serverImpl) GetServers(ctx context.Context, req *channelzpb.GetServersRequest) (*channelzpb.GetServersResponse, error) {
metrics, end := channelz.GetServers(req.GetStartServerId())
resp := &channelzpb.GetServersResponse{}
for _, m := range metrics {
resp.Server = append(resp.Server, serverMetricToProto(m))
}
resp.End = end
return resp, nil
}
func (s *serverImpl) GetServerSockets(ctx context.Context, req *channelzpb.GetServerSocketsRequest) (*channelzpb.GetServerSocketsResponse, error) {
metrics, end := channelz.GetServerSockets(req.GetServerId(), req.GetStartSocketId())
resp := &channelzpb.GetServerSocketsResponse{}
for _, m := range metrics {
resp.SocketRef = append(resp.SocketRef, &channelzpb.SocketRef{SocketId: m.ID, Name: m.RefName})
}
resp.End = end
return resp, nil
}
func (s *serverImpl) GetChannel(ctx context.Context, req *channelzpb.GetChannelRequest) (*channelzpb.GetChannelResponse, error) {
var metric *channelz.ChannelMetric
if metric = channelz.GetChannel(req.GetChannelId()); metric == nil {
return &channelzpb.GetChannelResponse{}, nil
}
resp := &channelzpb.GetChannelResponse{Channel: channelMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetSubchannel(ctx context.Context, req *channelzpb.GetSubchannelRequest) (*channelzpb.GetSubchannelResponse, error) {
var metric *channelz.SubChannelMetric
if metric = channelz.GetSubChannel(req.GetSubchannelId()); metric == nil {
return &channelzpb.GetSubchannelResponse{}, nil
}
resp := &channelzpb.GetSubchannelResponse{Subchannel: subChannelMetricToProto(metric)}
return resp, nil
}
func (s *serverImpl) GetSocket(ctx context.Context, req *channelzpb.GetSocketRequest) (*channelzpb.GetSocketResponse, error) {
var metric *channelz.SocketMetric
if metric = channelz.GetSocket(req.GetSocketId()); metric == nil {
return &channelzpb.GetSocketResponse{}, nil
}
resp := &channelzpb.GetSocketResponse{Socket: socketMetricToProto(metric)}
return resp, nil
}

View File

@@ -0,0 +1,152 @@
// +build linux,!appengine,go1.7
// +build 386 amd64
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// SocketOptions is only supported on linux system. The functions defined in
// this file are to parse the socket option field and the test is specifically
// to verify the behavior of socket option parsing.
package service
import (
"reflect"
"strconv"
"testing"
"github.com/golang/protobuf/ptypes"
durpb "github.com/golang/protobuf/ptypes/duration"
"golang.org/x/net/context"
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/internal/channelz"
)
func init() {
// Assign protoToSocketOption to protoToSocketOpt in order to enable socket option
// data conversion from proto message to channelz defined struct.
protoToSocketOpt = protoToSocketOption
}
func convertToDuration(d *durpb.Duration) (sec int64, usec int64) {
if d != nil {
if dur, err := ptypes.Duration(d); err == nil {
sec = int64(int64(dur) / 1e9)
usec = (int64(dur) - sec*1e9) / 1e3
}
}
return
}
func protoToLinger(protoLinger *channelzpb.SocketOptionLinger) *unix.Linger {
linger := &unix.Linger{}
if protoLinger.GetActive() {
linger.Onoff = 1
}
lv, _ := convertToDuration(protoLinger.GetDuration())
linger.Linger = int32(lv)
return linger
}
func protoToSocketOption(skopts []*channelzpb.SocketOption) *channelz.SocketOptionData {
skdata := &channelz.SocketOptionData{}
for _, opt := range skopts {
switch opt.GetName() {
case "SO_LINGER":
protoLinger := &channelzpb.SocketOptionLinger{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoLinger)
if err == nil {
skdata.Linger = protoToLinger(protoLinger)
}
case "SO_RCVTIMEO":
protoTimeout := &channelzpb.SocketOptionTimeout{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoTimeout)
if err == nil {
skdata.RecvTimeout = protoToTime(protoTimeout)
}
case "SO_SNDTIMEO":
protoTimeout := &channelzpb.SocketOptionTimeout{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), protoTimeout)
if err == nil {
skdata.SendTimeout = protoToTime(protoTimeout)
}
case "TCP_INFO":
tcpi := &channelzpb.SocketOptionTcpInfo{}
err := ptypes.UnmarshalAny(opt.GetAdditional(), tcpi)
if err == nil {
skdata.TCPInfo = &unix.TCPInfo{
State: uint8(tcpi.TcpiState),
Ca_state: uint8(tcpi.TcpiCaState),
Retransmits: uint8(tcpi.TcpiRetransmits),
Probes: uint8(tcpi.TcpiProbes),
Backoff: uint8(tcpi.TcpiBackoff),
Options: uint8(tcpi.TcpiOptions),
Rto: tcpi.TcpiRto,
Ato: tcpi.TcpiAto,
Snd_mss: tcpi.TcpiSndMss,
Rcv_mss: tcpi.TcpiRcvMss,
Unacked: tcpi.TcpiUnacked,
Sacked: tcpi.TcpiSacked,
Lost: tcpi.TcpiLost,
Retrans: tcpi.TcpiRetrans,
Fackets: tcpi.TcpiFackets,
Last_data_sent: tcpi.TcpiLastDataSent,
Last_ack_sent: tcpi.TcpiLastAckSent,
Last_data_recv: tcpi.TcpiLastDataRecv,
Last_ack_recv: tcpi.TcpiLastAckRecv,
Pmtu: tcpi.TcpiPmtu,
Rcv_ssthresh: tcpi.TcpiRcvSsthresh,
Rtt: tcpi.TcpiRtt,
Rttvar: tcpi.TcpiRttvar,
Snd_ssthresh: tcpi.TcpiSndSsthresh,
Snd_cwnd: tcpi.TcpiSndCwnd,
Advmss: tcpi.TcpiAdvmss,
Reordering: tcpi.TcpiReordering}
}
}
}
return skdata
}
func TestGetSocketOptions(t *testing.T) {
channelz.NewChannelzStorage()
ss := []*dummySocket{
{
socketOptions: &channelz.SocketOptionData{
Linger: &unix.Linger{Onoff: 1, Linger: 2},
RecvTimeout: &unix.Timeval{Sec: 10, Usec: 1},
SendTimeout: &unix.Timeval{},
TCPInfo: &unix.TCPInfo{State: 1},
},
},
}
svr := newCZServer()
ids := make([]int64, len(ss))
svrID := channelz.RegisterServer(&dummyServer{}, "")
for i, s := range ss {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
}
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
}
}
}

View File

@@ -0,0 +1,658 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"fmt"
"net"
"reflect"
"strconv"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"golang.org/x/net/context"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
)
func init() {
channelz.TurnOn()
}
type protoToSocketOptFunc func([]*channelzpb.SocketOption) *channelz.SocketOptionData
// protoToSocketOpt is used in function socketProtoToStruct to extract socket option
// data from unmarshaled proto message.
// It is only defined under linux, non-appengine environment on x86 architecture.
var protoToSocketOpt protoToSocketOptFunc
// emptyTime is used for detecting unset value of time.Time type.
// For go1.7 and earlier, ptypes.Timestamp will fill in the loc field of time.Time
// with &utcLoc. However zero value of a time.Time type value loc field is nil.
// This behavior will make reflect.DeepEqual fail upon unset time.Time field,
// and cause false positive fatal error.
var emptyTime time.Time
type dummyChannel struct {
state connectivity.State
target string
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTimestamp time.Time
}
func (d *dummyChannel) ChannelzMetric() *channelz.ChannelInternalMetric {
return &channelz.ChannelInternalMetric{
State: d.state,
Target: d.target,
CallsStarted: d.callsStarted,
CallsSucceeded: d.callsSucceeded,
CallsFailed: d.callsFailed,
LastCallStartedTimestamp: d.lastCallStartedTimestamp,
}
}
type dummyServer struct {
callsStarted int64
callsSucceeded int64
callsFailed int64
lastCallStartedTimestamp time.Time
}
func (d *dummyServer) ChannelzMetric() *channelz.ServerInternalMetric {
return &channelz.ServerInternalMetric{
CallsStarted: d.callsStarted,
CallsSucceeded: d.callsSucceeded,
CallsFailed: d.callsFailed,
LastCallStartedTimestamp: d.lastCallStartedTimestamp,
}
}
type dummySocket struct {
streamsStarted int64
streamsSucceeded int64
streamsFailed int64
messagesSent int64
messagesReceived int64
keepAlivesSent int64
lastLocalStreamCreatedTimestamp time.Time
lastRemoteStreamCreatedTimestamp time.Time
lastMessageSentTimestamp time.Time
lastMessageReceivedTimestamp time.Time
localFlowControlWindow int64
remoteFlowControlWindow int64
socketOptions *channelz.SocketOptionData
localAddr net.Addr
remoteAddr net.Addr
security credentials.ChannelzSecurityValue
remoteName string
}
func (d *dummySocket) ChannelzMetric() *channelz.SocketInternalMetric {
return &channelz.SocketInternalMetric{
StreamsStarted: d.streamsStarted,
StreamsSucceeded: d.streamsSucceeded,
StreamsFailed: d.streamsFailed,
MessagesSent: d.messagesSent,
MessagesReceived: d.messagesReceived,
KeepAlivesSent: d.keepAlivesSent,
LastLocalStreamCreatedTimestamp: d.lastLocalStreamCreatedTimestamp,
LastRemoteStreamCreatedTimestamp: d.lastRemoteStreamCreatedTimestamp,
LastMessageSentTimestamp: d.lastMessageSentTimestamp,
LastMessageReceivedTimestamp: d.lastMessageReceivedTimestamp,
LocalFlowControlWindow: d.localFlowControlWindow,
RemoteFlowControlWindow: d.remoteFlowControlWindow,
SocketOptions: d.socketOptions,
LocalAddr: d.localAddr,
RemoteAddr: d.remoteAddr,
Security: d.security,
RemoteName: d.remoteName,
}
}
func channelProtoToStruct(c *channelzpb.Channel) *dummyChannel {
dc := &dummyChannel{}
pdata := c.GetData()
switch pdata.GetState().GetState() {
case channelzpb.ChannelConnectivityState_UNKNOWN:
// TODO: what should we set here?
case channelzpb.ChannelConnectivityState_IDLE:
dc.state = connectivity.Idle
case channelzpb.ChannelConnectivityState_CONNECTING:
dc.state = connectivity.Connecting
case channelzpb.ChannelConnectivityState_READY:
dc.state = connectivity.Ready
case channelzpb.ChannelConnectivityState_TRANSIENT_FAILURE:
dc.state = connectivity.TransientFailure
case channelzpb.ChannelConnectivityState_SHUTDOWN:
dc.state = connectivity.Shutdown
}
dc.target = pdata.GetTarget()
dc.callsStarted = pdata.CallsStarted
dc.callsSucceeded = pdata.CallsSucceeded
dc.callsFailed = pdata.CallsFailed
if t, err := ptypes.Timestamp(pdata.GetLastCallStartedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
dc.lastCallStartedTimestamp = t
}
}
return dc
}
func serverProtoToStruct(s *channelzpb.Server) *dummyServer {
ds := &dummyServer{}
pdata := s.GetData()
ds.callsStarted = pdata.CallsStarted
ds.callsSucceeded = pdata.CallsSucceeded
ds.callsFailed = pdata.CallsFailed
if t, err := ptypes.Timestamp(pdata.GetLastCallStartedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastCallStartedTimestamp = t
}
}
return ds
}
func socketProtoToStruct(s *channelzpb.Socket) *dummySocket {
ds := &dummySocket{}
pdata := s.GetData()
ds.streamsStarted = pdata.GetStreamsStarted()
ds.streamsSucceeded = pdata.GetStreamsSucceeded()
ds.streamsFailed = pdata.GetStreamsFailed()
ds.messagesSent = pdata.GetMessagesSent()
ds.messagesReceived = pdata.GetMessagesReceived()
ds.keepAlivesSent = pdata.GetKeepAlivesSent()
if t, err := ptypes.Timestamp(pdata.GetLastLocalStreamCreatedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastLocalStreamCreatedTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastRemoteStreamCreatedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastRemoteStreamCreatedTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastMessageSentTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastMessageSentTimestamp = t
}
}
if t, err := ptypes.Timestamp(pdata.GetLastMessageReceivedTimestamp()); err == nil {
if !t.Equal(emptyTime) {
ds.lastMessageReceivedTimestamp = t
}
}
if v := pdata.GetLocalFlowControlWindow(); v != nil {
ds.localFlowControlWindow = v.Value
}
if v := pdata.GetRemoteFlowControlWindow(); v != nil {
ds.remoteFlowControlWindow = v.Value
}
if v := pdata.GetOption(); v != nil && protoToSocketOpt != nil {
ds.socketOptions = protoToSocketOpt(v)
}
if v := s.GetSecurity(); v != nil {
ds.security = protoToSecurity(v)
}
if local := s.GetLocal(); local != nil {
ds.localAddr = protoToAddr(local)
}
if remote := s.GetRemote(); remote != nil {
ds.remoteAddr = protoToAddr(remote)
}
ds.remoteName = s.GetRemoteName()
return ds
}
func protoToSecurity(protoSecurity *channelzpb.Security) credentials.ChannelzSecurityValue {
switch v := protoSecurity.Model.(type) {
case *channelzpb.Security_Tls_:
return &credentials.TLSChannelzSecurityValue{StandardName: v.Tls.GetStandardName(), LocalCertificate: v.Tls.GetLocalCertificate(), RemoteCertificate: v.Tls.GetRemoteCertificate()}
case *channelzpb.Security_Other:
sv := &credentials.OtherChannelzSecurityValue{Name: v.Other.GetName()}
var x ptypes.DynamicAny
if err := ptypes.UnmarshalAny(v.Other.GetValue(), &x); err == nil {
sv.Value = x.Message
}
return sv
}
return nil
}
func protoToAddr(a *channelzpb.Address) net.Addr {
switch v := a.Address.(type) {
case *channelzpb.Address_TcpipAddress:
if port := v.TcpipAddress.GetPort(); port != 0 {
return &net.TCPAddr{IP: v.TcpipAddress.GetIpAddress(), Port: int(port)}
}
return &net.IPAddr{IP: v.TcpipAddress.GetIpAddress()}
case *channelzpb.Address_UdsAddress_:
return &net.UnixAddr{Name: v.UdsAddress.GetFilename(), Net: "unix"}
case *channelzpb.Address_OtherAddress_:
// TODO:
}
return nil
}
func convertSocketRefSliceToMap(sktRefs []*channelzpb.SocketRef) map[int64]string {
m := make(map[int64]string)
for _, sr := range sktRefs {
m[sr.SocketId] = sr.Name
}
return m
}
type OtherSecurityValue struct {
LocalCertificate []byte `protobuf:"bytes,1,opt,name=local_certificate,json=localCertificate,proto3" json:"local_certificate,omitempty"`
RemoteCertificate []byte `protobuf:"bytes,2,opt,name=remote_certificate,json=remoteCertificate,proto3" json:"remote_certificate,omitempty"`
}
func (m *OtherSecurityValue) Reset() { *m = OtherSecurityValue{} }
func (m *OtherSecurityValue) String() string { return proto.CompactTextString(m) }
func (*OtherSecurityValue) ProtoMessage() {}
func init() {
// Ad-hoc registering the proto type here to facilitate UnmarshalAny of OtherSecurityValue.
proto.RegisterType((*OtherSecurityValue)(nil), "grpc.credentials.OtherChannelzSecurityValue")
}
func TestGetTopChannels(t *testing.T) {
tcs := []*dummyChannel{
{
state: connectivity.Connecting,
target: "test.channelz:1234",
callsStarted: 6,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
state: connectivity.Connecting,
target: "test.channelz:1234",
callsStarted: 1,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
state: connectivity.Shutdown,
target: "test.channelz:8888",
callsStarted: 0,
callsSucceeded: 0,
callsFailed: 0,
},
{},
}
channelz.NewChannelzStorage()
for _, c := range tcs {
channelz.RegisterChannel(c, 0, "")
}
s := newCZServer()
resp, _ := s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
for i, c := range resp.GetChannel() {
if !reflect.DeepEqual(channelProtoToStruct(c), tcs[i]) {
t.Fatalf("dummyChannel: %d, want: %#v, got: %#v", i, tcs[i], channelProtoToStruct(c))
}
}
for i := 0; i < 50; i++ {
channelz.RegisterChannel(tcs[0], 0, "")
}
resp, _ = s.GetTopChannels(context.Background(), &channelzpb.GetTopChannelsRequest{StartChannelId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetServers(t *testing.T) {
ss := []*dummyServer{
{
callsStarted: 6,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
callsStarted: 1,
callsSucceeded: 2,
callsFailed: 3,
lastCallStartedTimestamp: time.Now().UTC(),
},
{
callsStarted: 1,
callsSucceeded: 0,
callsFailed: 0,
lastCallStartedTimestamp: time.Now().UTC(),
},
}
channelz.NewChannelzStorage()
for _, s := range ss {
channelz.RegisterServer(s, "")
}
svr := newCZServer()
resp, _ := svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want true, got %v", resp.GetEnd())
}
for i, s := range resp.GetServer() {
if !reflect.DeepEqual(serverProtoToStruct(s), ss[i]) {
t.Fatalf("dummyServer: %d, want: %#v, got: %#v", i, ss[i], serverProtoToStruct(s))
}
}
for i := 0; i < 50; i++ {
channelz.RegisterServer(ss[0], "")
}
resp, _ = svr.GetServers(context.Background(), &channelzpb.GetServersRequest{StartServerId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetServerSockets(t *testing.T) {
channelz.NewChannelzStorage()
svrID := channelz.RegisterServer(&dummyServer{}, "")
refNames := []string{"listen socket 1", "normal socket 1", "normal socket 2"}
ids := make([]int64, 3)
ids[0] = channelz.RegisterListenSocket(&dummySocket{}, svrID, refNames[0])
ids[1] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[1])
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, svrID, refNames[2])
svr := newCZServer()
resp, _ := svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if !resp.GetEnd() {
t.Fatalf("resp.GetEnd() want: true, got: %v", resp.GetEnd())
}
// GetServerSockets only return normal sockets.
want := map[int64]string{
ids[1]: refNames[1],
ids[2]: refNames[2],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(resp.GetSocketRef()), want) {
t.Fatalf("GetServerSockets want: %#v, got: %#v", want, resp.GetSocketRef())
}
for i := 0; i < 50; i++ {
channelz.RegisterNormalSocket(&dummySocket{}, svrID, "")
}
resp, _ = svr.GetServerSockets(context.Background(), &channelzpb.GetServerSocketsRequest{ServerId: svrID, StartSocketId: 0})
if resp.GetEnd() {
t.Fatalf("resp.GetEnd() want false, got %v", resp.GetEnd())
}
}
func TestGetChannel(t *testing.T) {
channelz.NewChannelzStorage()
refNames := []string{"top channel 1", "nested channel 1", "sub channel 2", "nested channel 3"}
ids := make([]int64, 4)
ids[0] = channelz.RegisterChannel(&dummyChannel{}, 0, refNames[0])
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
})
ids[1] = channelz.RegisterChannel(&dummyChannel{}, ids[0], refNames[1])
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[1]),
Severity: channelz.CtINFO,
},
})
ids[2] = channelz.RegisterSubChannel(&dummyChannel{}, ids[0], refNames[2])
channelz.AddTraceEvent(ids[2], &channelz.TraceEventDesc{
Desc: "SubChannel Created",
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("SubChannel(id:%d) created", ids[2]),
Severity: channelz.CtINFO,
},
})
ids[3] = channelz.RegisterChannel(&dummyChannel{}, ids[1], refNames[3])
channelz.AddTraceEvent(ids[3], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[3]),
Severity: channelz.CtINFO,
},
})
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Channel Connectivity change to %v", connectivity.Ready),
Severity: channelz.CtINFO,
})
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: "Resolver returns an empty address list",
Severity: channelz.CtWarning,
})
svr := newCZServer()
resp, _ := svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[0]})
metrics := resp.GetChannel()
subChans := metrics.GetSubchannelRef()
if len(subChans) != 1 || subChans[0].GetName() != refNames[2] || subChans[0].GetSubchannelId() != ids[2] {
t.Fatalf("metrics.GetSubChannelRef() want %#v, got %#v", []*channelzpb.SubchannelRef{{SubchannelId: ids[2], Name: refNames[2]}}, subChans)
}
nestedChans := metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[1] || nestedChans[0].GetChannelId() != ids[1] {
t.Fatalf("metrics.GetChannelRef() want %#v, got %#v", []*channelzpb.ChannelRef{{ChannelId: ids[1], Name: refNames[1]}}, nestedChans)
}
trace := metrics.GetData().GetTrace()
want := []struct {
desc string
severity channelzpb.ChannelTraceEvent_Severity
childID int64
childRef string
}{
{desc: "Channel Created", severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[1]), severity: channelzpb.ChannelTraceEvent_CT_INFO, childID: ids[1], childRef: refNames[1]},
{desc: fmt.Sprintf("SubChannel(id:%d) created", ids[2]), severity: channelzpb.ChannelTraceEvent_CT_INFO, childID: ids[2], childRef: refNames[2]},
{desc: fmt.Sprintf("Channel Connectivity change to %v", connectivity.Ready), severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: "Resolver returns an empty address list", severity: channelzpb.ChannelTraceEvent_CT_WARNING},
}
for i, e := range trace.Events {
if e.GetDescription() != want[i].desc {
t.Fatalf("trace: GetDescription want %#v, got %#v", want[i].desc, e.GetDescription())
}
if e.GetSeverity() != want[i].severity {
t.Fatalf("trace: GetSeverity want %#v, got %#v", want[i].severity, e.GetSeverity())
}
if want[i].childID == 0 && (e.GetChannelRef() != nil || e.GetSubchannelRef() != nil) {
t.Fatalf("trace: GetChannelRef() should return nil, as there is no reference")
}
if e.GetChannelRef().GetChannelId() != want[i].childID || e.GetChannelRef().GetName() != want[i].childRef {
if e.GetSubchannelRef().GetSubchannelId() != want[i].childID || e.GetSubchannelRef().GetName() != want[i].childRef {
t.Fatalf("trace: GetChannelRef/GetSubchannelRef want (child ID: %d, child name: %q), got %#v and %#v", want[i].childID, want[i].childRef, e.GetChannelRef(), e.GetSubchannelRef())
}
}
}
resp, _ = svr.GetChannel(context.Background(), &channelzpb.GetChannelRequest{ChannelId: ids[1]})
metrics = resp.GetChannel()
nestedChans = metrics.GetChannelRef()
if len(nestedChans) != 1 || nestedChans[0].GetName() != refNames[3] || nestedChans[0].GetChannelId() != ids[3] {
t.Fatalf("metrics.GetChannelRef() want %#v, got %#v", []*channelzpb.ChannelRef{{ChannelId: ids[3], Name: refNames[3]}}, nestedChans)
}
}
func TestGetSubChannel(t *testing.T) {
var (
subchanCreated = "SubChannel Created"
subchanConnectivityChange = fmt.Sprintf("Subchannel Connectivity change to %v", connectivity.Ready)
subChanPickNewAddress = fmt.Sprintf("Subchannel picks a new address %q to connect", "0.0.0.0")
)
channelz.NewChannelzStorage()
refNames := []string{"top channel 1", "sub channel 1", "socket 1", "socket 2"}
ids := make([]int64, 4)
ids[0] = channelz.RegisterChannel(&dummyChannel{}, 0, refNames[0])
channelz.AddTraceEvent(ids[0], &channelz.TraceEventDesc{
Desc: "Channel Created",
Severity: channelz.CtINFO,
})
ids[1] = channelz.RegisterSubChannel(&dummyChannel{}, ids[0], refNames[1])
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: subchanCreated,
Severity: channelz.CtINFO,
Parent: &channelz.TraceEventDesc{
Desc: fmt.Sprintf("Nested Channel(id:%d) created", ids[0]),
Severity: channelz.CtINFO,
},
})
ids[2] = channelz.RegisterNormalSocket(&dummySocket{}, ids[1], refNames[2])
ids[3] = channelz.RegisterNormalSocket(&dummySocket{}, ids[1], refNames[3])
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: subchanConnectivityChange,
Severity: channelz.CtINFO,
})
channelz.AddTraceEvent(ids[1], &channelz.TraceEventDesc{
Desc: subChanPickNewAddress,
Severity: channelz.CtINFO,
})
svr := newCZServer()
resp, _ := svr.GetSubchannel(context.Background(), &channelzpb.GetSubchannelRequest{SubchannelId: ids[1]})
metrics := resp.GetSubchannel()
want := map[int64]string{
ids[2]: refNames[2],
ids[3]: refNames[3],
}
if !reflect.DeepEqual(convertSocketRefSliceToMap(metrics.GetSocketRef()), want) {
t.Fatalf("metrics.GetSocketRef() want %#v: got: %#v", want, metrics.GetSocketRef())
}
trace := metrics.GetData().GetTrace()
wantTrace := []struct {
desc string
severity channelzpb.ChannelTraceEvent_Severity
childID int64
childRef string
}{
{desc: subchanCreated, severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: subchanConnectivityChange, severity: channelzpb.ChannelTraceEvent_CT_INFO},
{desc: subChanPickNewAddress, severity: channelzpb.ChannelTraceEvent_CT_INFO},
}
for i, e := range trace.Events {
if e.GetDescription() != wantTrace[i].desc {
t.Fatalf("trace: GetDescription want %#v, got %#v", wantTrace[i].desc, e.GetDescription())
}
if e.GetSeverity() != wantTrace[i].severity {
t.Fatalf("trace: GetSeverity want %#v, got %#v", wantTrace[i].severity, e.GetSeverity())
}
if wantTrace[i].childID == 0 && (e.GetChannelRef() != nil || e.GetSubchannelRef() != nil) {
t.Fatalf("trace: GetChannelRef() should return nil, as there is no reference")
}
if e.GetChannelRef().GetChannelId() != wantTrace[i].childID || e.GetChannelRef().GetName() != wantTrace[i].childRef {
if e.GetSubchannelRef().GetSubchannelId() != wantTrace[i].childID || e.GetSubchannelRef().GetName() != wantTrace[i].childRef {
t.Fatalf("trace: GetChannelRef/GetSubchannelRef want (child ID: %d, child name: %q), got %#v and %#v", wantTrace[i].childID, wantTrace[i].childRef, e.GetChannelRef(), e.GetSubchannelRef())
}
}
}
}
func TestGetSocket(t *testing.T) {
channelz.NewChannelzStorage()
ss := []*dummySocket{
{
streamsStarted: 10,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastLocalStreamCreatedTimestamp: time.Now().UTC(),
lastRemoteStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 1024,
localAddr: &net.TCPAddr{IP: net.ParseIP("1.0.0.1"), Port: 10001},
remoteAddr: &net.TCPAddr{IP: net.ParseIP("12.0.0.1"), Port: 10002},
remoteName: "remote.remote",
},
{
streamsStarted: 10,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastRemoteStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 1024,
localAddr: &net.UnixAddr{Name: "file.path", Net: "unix"},
remoteAddr: &net.UnixAddr{Name: "another.path", Net: "unix"},
remoteName: "remote.remote",
},
{
streamsStarted: 5,
streamsSucceeded: 2,
streamsFailed: 3,
messagesSent: 20,
messagesReceived: 10,
keepAlivesSent: 2,
lastLocalStreamCreatedTimestamp: time.Now().UTC(),
lastMessageSentTimestamp: time.Now().UTC(),
lastMessageReceivedTimestamp: time.Now().UTC(),
localFlowControlWindow: 65536,
remoteFlowControlWindow: 10240,
localAddr: &net.IPAddr{IP: net.ParseIP("1.0.0.1")},
remoteAddr: &net.IPAddr{IP: net.ParseIP("9.0.0.1")},
remoteName: "",
},
{
localAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 10001},
},
{
security: &credentials.TLSChannelzSecurityValue{
StandardName: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
RemoteCertificate: []byte{48, 130, 2, 156, 48, 130, 2, 5, 160},
},
},
{
security: &credentials.OtherChannelzSecurityValue{
Name: "XXXX",
},
},
{
security: &credentials.OtherChannelzSecurityValue{
Name: "YYYY",
Value: &OtherSecurityValue{LocalCertificate: []byte{1, 2, 3}, RemoteCertificate: []byte{4, 5, 6}},
},
},
}
svr := newCZServer()
ids := make([]int64, len(ss))
svrID := channelz.RegisterServer(&dummyServer{}, "")
for i, s := range ss {
ids[i] = channelz.RegisterNormalSocket(s, svrID, strconv.Itoa(i))
}
for i, s := range ss {
resp, _ := svr.GetSocket(context.Background(), &channelzpb.GetSocketRequest{SocketId: ids[i]})
metrics := resp.GetSocket()
if !reflect.DeepEqual(metrics.GetRef(), &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}) || !reflect.DeepEqual(socketProtoToStruct(metrics), s) {
t.Fatalf("resp.GetSocket() want: metrics.GetRef() = %#v and %#v, got: metrics.GetRef() = %#v and %#v", &channelzpb.SocketRef{SocketId: ids[i], Name: strconv.Itoa(i)}, s, metrics.GetRef(), socketProtoToStruct(metrics))
}
}
}

View File

@@ -0,0 +1,33 @@
// +build 386,linux,!appengine,go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)
func protoToTime(protoTime *channelzpb.SocketOptionTimeout) *unix.Timeval {
timeout := &unix.Timeval{}
sec, usec := convertToDuration(protoTime.GetDuration())
timeout.Sec, timeout.Usec = int32(sec), int32(usec)
return timeout
}

View File

@@ -0,0 +1,32 @@
// +build amd64,linux,!appengine,go1.7
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"golang.org/x/sys/unix"
channelzpb "google.golang.org/grpc/channelz/grpc_channelz_v1"
)
func protoToTime(protoTime *channelzpb.SocketOptionTimeout) *unix.Timeval {
timeout := &unix.Timeval{}
timeout.Sec, timeout.Usec = convertToDuration(protoTime.GetDuration())
return timeout
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,515 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
const stateRecordingBalancerName = "state_recoding_balancer"
var testBalancer = &stateRecordingBalancer{}
func init() {
balancer.Register(testBalancer)
}
// These tests use a pipeListener. This listener is similar to net.Listener except that it is unbuffered, so each read
// and write will wait for the other side's corresponding write or read.
func TestStateTransitions_SingleAddress(t *testing.T) {
defer leakcheck.Check(t)
mctBkp := getMinConnectTimeout()
defer func() {
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
}()
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*100)
for _, test := range []struct {
desc string
want []connectivity.State
server func(net.Listener) net.Conn
}{
{
desc: "When the server returns server preface, the client enters READY.",
want: []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return nil
}
return conn
},
},
{
desc: "When the connection is closed, the client enters TRANSIENT FAILURE.",
want: []connectivity.State{
connectivity.Connecting,
connectivity.TransientFailure,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
conn.Close()
return nil
},
},
{
desc: `When the server sends its connection preface, but the connection dies before the client can write its
connection preface, the client enters TRANSIENT FAILURE.`,
want: []connectivity.State{
connectivity.Connecting,
connectivity.TransientFailure,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return nil
}
conn.Close()
return nil
},
},
{
desc: `When the server reads the client connection preface but does not send its connection preface, the
client enters TRANSIENT FAILURE.`,
want: []connectivity.State{
connectivity.Connecting,
connectivity.TransientFailure,
},
server: func(lis net.Listener) net.Conn {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return nil
}
go keepReading(conn)
return conn
},
},
} {
t.Log(test.desc)
testStateTransitionSingleAddress(t, test.want, test.server)
}
}
func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener) net.Conn) {
defer leakcheck.Check(t)
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
pl := testutils.NewPipeListener()
defer pl.Close()
// Launch the server.
var conn net.Conn
var connMu sync.Mutex
go func() {
connMu.Lock()
conn = server(pl)
connMu.Unlock()
}()
client, err := DialContext(ctx, "", WithWaitForHandshake(), WithInsecure(),
WithBalancerName(stateRecordingBalancerName), WithDialer(pl.Dialer()), withBackoff(noBackoff{}))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(5 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
connMu.Lock()
defer connMu.Unlock()
if conn != nil {
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}
}
// When a READY connection is closed, the client enters TRANSIENT FAILURE before CONNECTING.
func TestStateTransition_ReadyToTransientFailure(t *testing.T) {
defer leakcheck.Check(t)
want := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.TransientFailure,
connectivity.Connecting,
}
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis.Close()
sawReady := make(chan struct{})
// Launch the server.
go func() {
conn, err := lis.Accept()
if err != nil {
t.Error(err)
return
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return
}
// Prevents race between onPrefaceReceipt and onClose.
<-sawReady
conn.Close()
}()
client, err := DialContext(ctx, lis.Addr().String(), WithWaitForHandshake(), WithInsecure(), WithBalancerName(stateRecordingBalancerName))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(5 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen == connectivity.Ready {
close(sawReady)
}
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
}
// When the first connection is closed, the client enters stays in CONNECTING until it tries the second
// address (which succeeds, and then it enters READY).
func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {
defer leakcheck.Check(t)
want := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
lis1, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis1.Close()
lis2, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis2.Close()
server1Done := make(chan struct{})
server2Done := make(chan struct{})
// Launch server 1.
go func() {
conn, err := lis1.Accept()
if err != nil {
t.Error(err)
return
}
conn.Close()
close(server1Done)
}()
// Launch server 2.
go func() {
conn, err := lis2.Accept()
if err != nil {
t.Error(err)
return
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return
}
close(server2Done)
}()
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialAddrs([]resolver.Address{
{Addr: lis1.Addr().String()},
{Addr: lis2.Addr().String()},
})
client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(5 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
select {
case <-timeout:
t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 1")
case <-server1Done:
}
select {
case <-timeout:
t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 2")
case <-server2Done:
}
}
// When there are multiple addresses, and we enter READY on one of them, a later closure should cause
// the client to enter TRANSIENT FAILURE before it re-enters CONNECTING.
func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
defer leakcheck.Check(t)
want := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.TransientFailure,
connectivity.Connecting,
}
stateNotifications := make(chan connectivity.State, len(want))
testBalancer.ResetNotifier(stateNotifications)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
lis1, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis1.Close()
// Never actually gets used; we just want it to be alive so that the resolver has two addresses to target.
lis2, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis2.Close()
server1Done := make(chan struct{})
sawReady := make(chan struct{})
// Launch server 1.
go func() {
conn, err := lis1.Accept()
if err != nil {
t.Error(err)
return
}
go keepReading(conn)
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings frame. %v", err)
return
}
<-sawReady
conn.Close()
_, err = lis1.Accept()
if err != nil {
t.Error(err)
return
}
close(server1Done)
}()
rb := manual.NewBuilderWithScheme("whatever")
rb.InitialAddrs([]resolver.Address{
{Addr: lis1.Addr().String()},
{Addr: lis2.Addr().String()},
})
client, err := DialContext(ctx, "this-gets-overwritten", WithInsecure(), WithWaitForHandshake(), WithBalancerName(stateRecordingBalancerName), withResolverBuilder(rb))
if err != nil {
t.Fatal(err)
}
defer client.Close()
timeout := time.After(2 * time.Second)
for i := 0; i < len(want); i++ {
select {
case <-timeout:
t.Fatalf("timed out waiting for state %d (%v) in flow %v", i, want[i], want)
case seen := <-stateNotifications:
if seen == connectivity.Ready {
close(sawReady)
}
if seen != want[i] {
t.Fatalf("expected to see %v at position %d in flow %v, got %v", want[i], i, want, seen)
}
}
}
select {
case <-timeout:
t.Fatal("saw the correct state transitions, but timed out waiting for client to finish interactions with server 1")
case <-server1Done:
}
}
type stateRecordingBalancer struct {
mu sync.Mutex
notifier chan<- connectivity.State
balancer.Balancer
}
func (b *stateRecordingBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
b.mu.Lock()
b.notifier <- s
b.mu.Unlock()
b.Balancer.HandleSubConnStateChange(sc, s)
}
func (b *stateRecordingBalancer) ResetNotifier(r chan<- connectivity.State) {
b.mu.Lock()
defer b.mu.Unlock()
b.notifier = r
}
func (b *stateRecordingBalancer) Close() {
b.mu.Lock()
u := b.Balancer
b.mu.Unlock()
u.Close()
}
func (b *stateRecordingBalancer) Name() string {
return stateRecordingBalancerName
}
func (b *stateRecordingBalancer) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
b.mu.Lock()
b.Balancer = balancer.Get(PickFirstBalancerName).Build(cc, opts)
b.mu.Unlock()
return b
}
type noBackoff struct{}
func (b noBackoff) Backoff(int) time.Duration { return time.Duration(0) }
// Keep reading until something causes the connection to die (EOF, server closed, etc). Useful
// as a tool for mindlessly keeping the connection healthy, since the client will error if
// things like client prefaces are not accepted in a timely fashion.
func keepReading(conn net.Conn) {
buf := make([]byte, 1024)
for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) {
}
}

View File

@@ -19,26 +19,39 @@
package grpc
import (
"io"
"errors"
"fmt"
"math"
"net"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
_ "google.golang.org/grpc/resolver/passthrough"
"google.golang.org/grpc/test/leakcheck"
"google.golang.org/grpc/testdata"
)
var (
mutableMinConnectTimeout = time.Second * 20
)
func init() {
getMinConnectTimeout = func() time.Duration {
return time.Duration(atomic.LoadInt64((*int64)(&mutableMinConnectTimeout)))
}
}
func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.State, bool) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -121,11 +134,11 @@ func TestDialWithMultipleBackendsNotSendingServerPreface(t *testing.T) {
func TestDialWaitsForServerSettings(t *testing.T) {
defer leakcheck.Check(t)
server, err := net.Listen("tcp", "localhost:0")
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer server.Close()
defer lis.Close()
done := make(chan struct{})
sent := make(chan struct{})
dialDone := make(chan struct{})
@@ -133,15 +146,15 @@ func TestDialWaitsForServerSettings(t *testing.T) {
defer func() {
close(done)
}()
conn, err := server.Accept()
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
defer conn.Close()
// Sleep so that if the test were to fail it
// will fail more often than not.
time.Sleep(100 * time.Millisecond)
// Sleep for a little bit to make sure that Dial on client
// side blocks until settings are received.
time.Sleep(500 * time.Millisecond)
framer := http2.NewFramer(conn, conn)
close(sent)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
@@ -150,12 +163,11 @@ func TestDialWaitsForServerSettings(t *testing.T) {
}
<-dialDone // Close conn only after dial returns.
}()
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, err := DialContext(ctx, server.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock())
client, err := DialContext(ctx, lis.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock())
close(dialDone)
if err != nil {
cancel()
t.Fatalf("Error while dialing. Err: %v", err)
}
defer client.Close()
@@ -168,118 +180,140 @@ func TestDialWaitsForServerSettings(t *testing.T) {
}
func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
mctBkp := minConnectTimeout
// Call this only after transportMonitor goroutine has ended.
defer func() {
minConnectTimeout = mctBkp
}()
func TestDialWaitsForServerSettingsAndFails(t *testing.T) {
defer leakcheck.Check(t)
minConnectTimeout = time.Millisecond * 500
server, err := net.Listen("tcp", "localhost:0")
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer server.Close()
done := make(chan struct{})
clientDone := make(chan struct{})
numConns := 0
go func() { // Launch the server.
defer func() {
if done != nil {
close(done)
}
close(done)
}()
conn1, err := server.Accept()
for {
conn, err := lis.Accept()
if err != nil {
break
}
numConns++
defer conn.Close()
}
}()
getMinConnectTimeout = func() time.Duration { return time.Second / 2 }
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
client, err := DialContext(ctx, lis.Addr().String(), WithInsecure(), WithWaitForHandshake(), WithBlock())
lis.Close()
if err == nil {
client.Close()
t.Fatalf("Unexpected success (err=nil) while dialing")
}
if err != context.DeadlineExceeded {
t.Fatalf("DialContext(_) = %v; want context.DeadlineExceeded", err)
}
if numConns < 2 {
t.Fatalf("dial attempts: %v; want > 1", numConns)
}
<-done
}
func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
// 1. Client connects to a server that doesn't send preface.
// 2. After minConnectTimeout(500 ms here), client disconnects and retries.
// 3. The new server sends its preface.
// 4. Client doesn't kill the connection this time.
mctBkp := getMinConnectTimeout()
defer func() {
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
}()
defer leakcheck.Check(t)
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*500)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
var (
conn2 net.Conn
over uint32
)
defer func() {
lis.Close()
// conn2 shouldn't be closed until the client has
// observed a successful test.
if conn2 != nil {
conn2.Close()
}
}()
done := make(chan struct{})
accepted := make(chan struct{})
go func() { // Launch the server.
defer close(done)
conn1, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
defer conn1.Close()
// Don't send server settings and make sure the connection is closed.
time.Sleep(time.Millisecond * 1500) // Since the first backoff is for a second.
conn1.SetDeadline(time.Now().Add(time.Second))
b := make([]byte, 24)
for {
// Make sure the connection was closed by client.
_, err = conn1.Read(b)
if err == nil {
continue
}
if err != io.EOF {
t.Errorf(" conn1.Read(_) = _, %v, want _, io.EOF", err)
return
}
break
}
conn2, err := server.Accept() // Accept a reconnection request from client.
// Don't send server settings and the client should close the connection and try again.
conn2, err = lis.Accept() // Accept a reconnection request from client.
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
defer conn2.Close()
close(accepted)
framer := http2.NewFramer(conn2, conn2)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
if err = framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings. Err: %v", err)
return
}
time.Sleep(time.Millisecond * 1500) // Since the first backoff is for a second.
conn2.SetDeadline(time.Now().Add(time.Millisecond * 500))
b := make([]byte, 8)
for {
// Make sure the connection stays open and is closed
// only by connection timeout.
_, err = conn2.Read(b)
if err == nil {
continue
}
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
if atomic.LoadUint32(&over) == 1 {
// The connection stayed alive for the timer.
// Success.
return
}
t.Errorf("Unexpected error while reading. Err: %v, want timeout error", err)
break
}
close(done)
done = nil
<-clientDone
}()
client, err := Dial(server.Addr().String(), WithInsecure())
client, err := Dial(lis.Addr().String(), WithInsecure())
if err != nil {
t.Fatalf("Error while dialing. Err: %v", err)
}
<-done
// TODO: The code from BEGIN to END should be delete once issue
// https://github.com/grpc/grpc-go/issues/1750 is fixed.
// BEGIN
// Set underlying addrConns state to Shutdown so that no reconnect
// attempts take place and thereby resetting minConnectTimeout is
// race free.
client.mu.Lock()
addrConns := client.conns
client.mu.Unlock()
for ac := range addrConns {
ac.mu.Lock()
ac.state = connectivity.Shutdown
ac.mu.Unlock()
// wait for connection to be accepted on the server.
timer := time.NewTimer(time.Second * 10)
select {
case <-accepted:
case <-timer.C:
t.Fatalf("Client didn't make another connection request in time.")
}
// END
// Make sure the connection stays alive for sometime.
time.Sleep(time.Second * 2)
atomic.StoreUint32(&over, 1)
client.Close()
close(clientDone)
<-done
}
func TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
defer leakcheck.Check(t)
server, err := net.Listen("tcp", "localhost:0")
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer server.Close()
defer lis.Close()
done := make(chan struct{})
go func() { // Launch the server.
defer func() {
close(done)
}()
conn, err := server.Accept() // Accept the connection only to close it immediately.
conn, err := lis.Accept() // Accept the connection only to close it immediately.
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
@@ -289,7 +323,7 @@ func TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
var prevDuration time.Duration
// Make sure the retry attempts are backed off properly.
for i := 0; i < 3; i++ {
conn, err := server.Accept()
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
@@ -305,7 +339,7 @@ func TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
prevAt = meow
}
}()
client, err := Dial(server.Addr().String(), WithInsecure())
client, err := Dial(lis.Addr().String(), WithInsecure())
if err != nil {
t.Fatalf("Error while dialing. Err: %v", err)
}
@@ -351,7 +385,7 @@ func TestConnectivityStates(t *testing.T) {
}
func TestDialTimeout(t *testing.T) {
func TestWithTimeout(t *testing.T) {
defer leakcheck.Check(t)
conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure())
if err == nil {
@@ -362,13 +396,15 @@ func TestDialTimeout(t *testing.T) {
}
}
func TestTLSDialTimeout(t *testing.T) {
func TestWithTransportCredentialsTLS(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
defer leakcheck.Check(t)
creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock())
conn, err := DialContext(ctx, "passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithBlock())
if err == nil {
conn.Close()
}
@@ -446,6 +482,26 @@ func TestDialContextCancel(t *testing.T) {
}
}
type failFastError struct{}
func (failFastError) Error() string { return "failfast" }
func (failFastError) Temporary() bool { return false }
func TestDialContextFailFast(t *testing.T) {
defer leakcheck.Check(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
failErr := failFastError{}
dialer := func(string, time.Duration) (net.Conn, error) {
return nil, failErr
}
_, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure(), WithDialer(dialer), FailOnNonTempDialError(true))
if terr, ok := err.(transport.ConnectionError); !ok || terr.Origin() != failErr {
t.Fatalf("DialContext() = _, %v, want _, %v", err, failErr)
}
}
// blockingBalancer mimics the behavior of balancers whose initialization takes a long time.
// In this test, reading from blockingBalancer.Notify() blocks forever.
type blockingBalancer struct {
@@ -520,7 +576,6 @@ func TestWithBackoffConfig(t *testing.T) {
defer leakcheck.Check(t)
b := BackoffConfig{MaxDelay: DefaultBackoffConfig.MaxDelay / 2}
expected := b
setDefaults(&expected) // defaults should be set
testBackoffConfigSet(t, &expected, WithBackoffConfig(b))
}
@@ -528,7 +583,6 @@ func TestWithBackoffMaxDelay(t *testing.T) {
defer leakcheck.Check(t)
md := DefaultBackoffConfig.MaxDelay / 2
expected := BackoffConfig{MaxDelay: md}
setDefaults(&expected)
testBackoffConfigSet(t, &expected, WithBackoffMaxDelay(md))
}
@@ -544,12 +598,15 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt
t.Fatalf("backoff config not set")
}
actual, ok := conn.dopts.bs.(BackoffConfig)
actual, ok := conn.dopts.bs.(backoff.Exponential)
if !ok {
t.Fatalf("unexpected type of backoff config: %#v", conn.dopts.bs)
}
if actual != *expected {
expectedValue := backoff.Exponential{
MaxDelay: expected.MaxDelay,
}
if actual != expectedValue {
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
}
}
@@ -617,6 +674,22 @@ func TestResolverServiceConfigBeforeAddressNotPanic(t *testing.T) {
time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
}
func TestResolverServiceConfigWhileClosingNotPanic(t *testing.T) {
defer leakcheck.Check(t)
for i := 0; i < 10; i++ { // Run this multiple times to make sure it doesn't panic.
r, rcleanup := manual.GenerateAndRegisterManualResolver()
defer rcleanup()
cc, err := Dial(r.Scheme()+":///test.server", WithInsecure())
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
// Send a new service config while closing the ClientConn.
go cc.Close()
go r.NewServiceConfig(`{"loadBalancingPolicy": "round_robin"}`) // This should not panic.
}
}
func TestResolverEmptyUpdateNotPanic(t *testing.T) {
defer leakcheck.Check(t)
r, rcleanup := manual.GenerateAndRegisterManualResolver()
@@ -662,3 +735,102 @@ func TestClientUpdatesParamsAfterGoAway(t *testing.T) {
t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 100ms", v)
}
}
func TestDisableServiceConfigOption(t *testing.T) {
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
addr := r.Scheme() + ":///non.existent"
cc, err := Dial(addr, WithInsecure(), WithDisableServiceConfig())
if err != nil {
t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
}
defer cc.Close()
r.NewServiceConfig(`{
"methodConfig": [
{
"name": [
{
"service": "foo",
"method": "Bar"
}
],
"waitForReady": true
}
]
}`)
time.Sleep(1 * time.Second)
m := cc.GetMethodConfig("/foo/Bar")
if m.WaitForReady != nil {
t.Fatalf("want: method (\"/foo/bar/\") config to be empty, got: %v", m)
}
}
func TestGetClientConnTarget(t *testing.T) {
addr := "nonexist:///non.existent"
cc, err := Dial(addr, WithInsecure())
if err != nil {
t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
}
defer cc.Close()
if cc.Target() != addr {
t.Fatalf("Target() = %s, want %s", cc.Target(), addr)
}
}
type backoffForever struct{}
func (b backoffForever) Backoff(int) time.Duration { return time.Duration(math.MaxInt64) }
func TestResetConnectBackoff(t *testing.T) {
defer leakcheck.Check(t)
dials := make(chan struct{})
defer func() { // If we fail, let the http2client break out of dialing.
select {
case <-dials:
default:
}
}()
dialer := func(string, time.Duration) (net.Conn, error) {
dials <- struct{}{}
return nil, errors.New("failed to fake dial")
}
cc, err := Dial("any", WithInsecure(), WithDialer(dialer), withBackoff(backoffForever{}))
if err != nil {
t.Fatalf("Dial() = _, %v; want _, nil", err)
}
defer cc.Close()
select {
case <-dials:
case <-time.NewTimer(10 * time.Second).C:
t.Fatal("Failed to call dial within 10s")
}
select {
case <-dials:
t.Fatal("Dial called unexpectedly before resetting backoff")
case <-time.NewTimer(100 * time.Millisecond).C:
}
cc.ResetConnectBackoff()
select {
case <-dials:
case <-time.NewTimer(10 * time.Second).C:
t.Fatal("Failed to call dial within 10s after resetting backoff")
}
}
func TestBackoffCancel(t *testing.T) {
defer leakcheck.Check(t)
dialStrCh := make(chan string)
cc, err := Dial("any", WithInsecure(), WithDialer(func(t string, _ time.Duration) (net.Conn, error) {
dialStrCh <- t
return nil, fmt.Errorf("test dialer, always error")
}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
<-dialStrCh
cc.Close()
// Should not leak. May need -count 5000 to exercise.
}

View File

@@ -22,6 +22,7 @@ package codes // import "google.golang.org/grpc/codes"
import (
"fmt"
"strconv"
)
// A Code is an unsigned 32-bit error code as defined in the gRPC spec.
@@ -143,6 +144,8 @@ const (
// Unauthenticated indicates the request does not have valid
// authentication credentials for the operation.
Unauthenticated Code = 16
_maxCode = 17
)
var strToCode = map[string]Code{
@@ -176,6 +179,16 @@ func (c *Code) UnmarshalJSON(b []byte) error {
if c == nil {
return fmt.Errorf("nil receiver passed to UnmarshalJSON")
}
if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil {
if ci >= _maxCode {
return fmt.Errorf("invalid code: %q", ci)
}
*c = Code(ci)
return nil
}
if jc, ok := strToCode[string(b)]; ok {
*c = jc
return nil

View File

@@ -62,3 +62,23 @@ func TestUnmarshalJSON_UnknownInput(t *testing.T) {
}
}
}
func TestUnmarshalJSON_MarshalUnmarshal(t *testing.T) {
for i := 0; i < _maxCode; i++ {
var cUnMarshaled Code
c := Code(i)
cJSON, err := json.Marshal(c)
if err != nil {
t.Errorf("marshalling %q failed: %v", c, err)
}
if err := json.Unmarshal(cJSON, &cUnMarshaled); err != nil {
t.Errorf("unmarshalling code failed: %s", err)
}
if c != cUnMarshaled {
t.Errorf("code is %q after marshalling/unmarshalling, expected %q", cUnMarshaled, c)
}
}
}

330
vendor/google.golang.org/grpc/credentials/alts/alts.go generated vendored Normal file
View File

@@ -0,0 +1,330 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package alts implements the ALTS credential support by gRPC library, which
// encapsulates all the state needed by a client to authenticate with a server
// using ALTS and make various assertions, e.g., about the client's identity,
// role, or whether it is authorized to make a particular call.
// This package is experimental.
package alts
import (
"errors"
"fmt"
"net"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/credentials/alts/internal/handshaker"
"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/grpclog"
)
const (
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
// handshaker service address in the hypervisor.
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
// defaultTimeout specifies the server handshake timeout.
defaultTimeout = 30.0 * time.Second
// The following constants specify the minimum and maximum acceptable
// protocol versions.
protocolVersionMaxMajor = 2
protocolVersionMaxMinor = 1
protocolVersionMinMajor = 2
protocolVersionMinMinor = 1
)
var (
once sync.Once
maxRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMaxMajor,
Minor: protocolVersionMaxMinor,
}
minRPCVersion = &altspb.RpcProtocolVersions_Version{
Major: protocolVersionMinMajor,
Minor: protocolVersionMinMinor,
}
// ErrUntrustedPlatform is returned from ClientHandshake and
// ServerHandshake is running on a platform where the trustworthiness of
// the handshaker service is not guaranteed.
ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
)
// AuthInfo exposes security information from the ALTS handshake to the
// application. This interface is to be implemented by ALTS. Users should not
// need a brand new implementation of this interface. For situations like
// testing, any new implementation should embed this interface. This allows
// ALTS to add new methods to this interface.
type AuthInfo interface {
// ApplicationProtocol returns application protocol negotiated for the
// ALTS connection.
ApplicationProtocol() string
// RecordProtocol returns the record protocol negotiated for the ALTS
// connection.
RecordProtocol() string
// SecurityLevel returns the security level of the created ALTS secure
// channel.
SecurityLevel() altspb.SecurityLevel
// PeerServiceAccount returns the peer service account.
PeerServiceAccount() string
// LocalServiceAccount returns the local service account.
LocalServiceAccount() string
// PeerRPCVersions returns the RPC version supported by the peer.
PeerRPCVersions() *altspb.RpcProtocolVersions
}
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
// TargetServiceAccounts contains a list of expected target service
// accounts.
TargetServiceAccounts []string
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultClientOptions creates a new ClientOptions object with the default
// values.
func DefaultClientOptions() *ClientOptions {
return &ClientOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultServerOptions creates a new ServerOptions object with the default
// values.
func DefaultServerOptions() *ServerOptions {
return &ServerOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
info *credentials.ProtocolInfo
side core.Side
accounts []string
hsAddress string
}
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
}
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
}
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
once.Do(func() {
vmOnGCP = isRunningOnGCP()
})
if hsAddress == "" {
hsAddress = hypervisorHandshakerServiceAddress
}
return &altsTC{
info: &credentials.ProtocolInfo{
SecurityProtocol: "alts",
SecurityVersion: "1.0",
},
side: side,
accounts: accounts,
hsAddress: hsAddress,
}
}
// ClientHandshake implements the client side handshake protocol.
func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !vmOnGCP {
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
// Do not close hsConn since it is shared with other handshakes.
// Possible context leak:
// The cancel function for the child context we create will only be
// called a non-nil error is returned.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()
opts := handshaker.DefaultClientHandshakerOptions()
opts.TargetName = addr
opts.TargetServiceAccounts = g.accounts
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
chs.Close()
}
}()
if err != nil {
return nil, nil, err
}
secConn, authInfo, err := chs.ClientHandshake(ctx)
if err != nil {
return nil, nil, err
}
altsAuthInfo, ok := authInfo.(AuthInfo)
if !ok {
return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
}
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
if !match {
return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
}
return secConn, authInfo, nil
}
// ServerHandshake implements the server side ALTS handshaker.
func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
if !vmOnGCP {
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
// Do not close hsConn since it's shared with other handshakes.
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
opts := handshaker.DefaultServerHandshakerOptions()
opts.RPCVersions = &altspb.RpcProtocolVersions{
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
defer func() {
if err != nil {
shs.Close()
}
}()
if err != nil {
return nil, nil, err
}
secConn, authInfo, err := shs.ServerHandshake(ctx)
if err != nil {
return nil, nil, err
}
altsAuthInfo, ok := authInfo.(AuthInfo)
if !ok {
return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
}
match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
if !match {
return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
}
return secConn, authInfo, nil
}
func (g *altsTC) Info() credentials.ProtocolInfo {
return *g.info
}
func (g *altsTC) Clone() credentials.TransportCredentials {
info := *g.info
var accounts []string
if g.accounts != nil {
accounts = make([]string, len(g.accounts))
copy(accounts, g.accounts)
}
return &altsTC{
info: &info,
side: g.side,
hsAddress: g.hsAddress,
accounts: accounts,
}
}
func (g *altsTC) OverrideServerName(serverNameOverride string) error {
g.info.ServerName = serverNameOverride
return nil
}
// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
switch {
case v1.GetMajor() > v2.GetMajor(),
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
return 1
case v1.GetMajor() < v2.GetMajor(),
v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
return -1
}
return 0
}
// checkRPCVersions performs a version check between local and peer rpc protocol
// versions. This function returns true if the check passes which means both
// parties agreed on a common rpc protocol to use, and false otherwise. The
// function also returns the highest common RPC protocol version both parties
// agreed on.
func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
if local == nil || peer == nil {
grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.")
return false, nil
}
// maxCommonVersion is MIN(local.max, peer.max).
maxCommonVersion := local.GetMaxRpcVersion()
if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
maxCommonVersion = peer.GetMaxRpcVersion()
}
// minCommonVersion is MAX(local.min, peer.min).
minCommonVersion := peer.GetMinRpcVersion()
if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
minCommonVersion = local.GetMinRpcVersion()
}
if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
return false, nil
}
return true, maxCommonVersion
}

View File

@@ -0,0 +1,290 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"reflect"
"testing"
"github.com/golang/protobuf/proto"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
func TestInfoServerName(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
alts := NewServerCreds(DefaultServerOptions())
if got, want := alts.Info().ServerName, ""; got != want {
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
}
}
func TestOverrideServerName(t *testing.T) {
wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
}
}
func TestCloneClient(t *testing.T) {
wantServerName := "server.name"
opt := DefaultClientOptions()
opt.TargetServiceAccounts = []string{"not", "empty"}
c := NewClientCreds(opt)
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
cc.OverrideServerName("")
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
}
if got, want := cc.Info().ServerName, ""; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
ct := c.(*altsTC)
cct := cc.(*altsTC)
if ct.side != cct.side {
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
}
if ct.hsAddress != cct.hsAddress {
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
}
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
}
}
func TestCloneServer(t *testing.T) {
wantServerName := "server.name"
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
cc.OverrideServerName("")
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
}
if got, want := cc.Info().ServerName, ""; got != want {
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
}
ct := c.(*altsTC)
cct := cc.(*altsTC)
if ct.side != cct.side {
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
}
if ct.hsAddress != cct.hsAddress {
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
}
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
}
}
func TestInfo(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds(DefaultServerOptions())
info := c.Info()
if got, want := info.ProtocolVersion, ""; got != want {
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
}
if got, want := info.SecurityProtocol, "alts"; got != want {
t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
}
if got, want := info.SecurityVersion, "1.0"; got != want {
t.Errorf("info.SecurityVersion=%v, want %v", got, want)
}
if got, want := info.ServerName, ""; got != want {
t.Errorf("info.ServerName=%v, want %v", got, want)
}
}
func TestCompareRPCVersions(t *testing.T) {
for _, tc := range []struct {
v1 *altspb.RpcProtocolVersions_Version
v2 *altspb.RpcProtocolVersions_Version
output int
}{
{
version(3, 2),
version(2, 1),
1,
},
{
version(3, 2),
version(3, 1),
1,
},
{
version(2, 1),
version(3, 2),
-1,
},
{
version(3, 1),
version(3, 2),
-1,
},
{
version(3, 2),
version(3, 2),
0,
},
} {
if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
}
}
}
func TestCheckRPCVersions(t *testing.T) {
for _, tc := range []struct {
desc string
local *altspb.RpcProtocolVersions
peer *altspb.RpcProtocolVersions
output bool
maxCommonVersion *altspb.RpcProtocolVersions_Version
}{
{
"local.max > peer.max and local.min > peer.min",
versions(2, 1, 3, 2),
versions(1, 2, 2, 1),
true,
version(2, 1),
},
{
"local.max > peer.max and local.min < peer.min",
versions(1, 2, 3, 2),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"local.max > peer.max and local.min = peer.min",
versions(2, 1, 3, 2),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min > peer.min",
versions(2, 1, 2, 1),
versions(1, 2, 3, 2),
true,
version(2, 1),
},
{
"local.max = peer.max and local.min > peer.min",
versions(2, 1, 2, 1),
versions(1, 2, 2, 1),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min < peer.min",
versions(1, 2, 2, 1),
versions(2, 1, 3, 2),
true,
version(2, 1),
},
{
"local.max < peer.max and local.min = peer.min",
versions(1, 2, 2, 1),
versions(1, 2, 3, 2),
true,
version(2, 1),
},
{
"local.max = peer.max and local.min < peer.min",
versions(1, 2, 2, 1),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"all equal",
versions(2, 1, 2, 1),
versions(2, 1, 2, 1),
true,
version(2, 1),
},
{
"max is smaller than min",
versions(2, 1, 1, 2),
versions(2, 1, 1, 2),
false,
nil,
},
{
"no overlap, local > peer",
versions(4, 3, 6, 5),
versions(1, 0, 2, 1),
false,
nil,
},
{
"no overlap, local < peer",
versions(1, 0, 2, 1),
versions(4, 3, 6, 5),
false,
nil,
},
{
"no overlap, max < min",
versions(6, 5, 4, 3),
versions(2, 1, 1, 0),
false,
nil,
},
} {
output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
if got, want := output, tc.output; got != want {
t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
}
if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
}
}
}
func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
return &altspb.RpcProtocolVersions_Version{
Major: major,
Minor: minor,
}
}
func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
return &altspb.RpcProtocolVersions{
MinRpcVersion: version(minMajor, minMinor),
MaxRpcVersion: version(maxMajor, maxMinor),
}
}

View File

@@ -0,0 +1,87 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package authinfo provide authentication information returned by handshakers.
package authinfo
import (
"google.golang.org/grpc/credentials"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
var _ credentials.AuthInfo = (*altsAuthInfo)(nil)
// altsAuthInfo exposes security information from the ALTS handshake to the
// application. altsAuthInfo is immutable and implements credentials.AuthInfo.
type altsAuthInfo struct {
p *altspb.AltsContext
}
// New returns a new altsAuthInfo object given handshaker results.
func New(result *altspb.HandshakerResult) credentials.AuthInfo {
return newAuthInfo(result)
}
func newAuthInfo(result *altspb.HandshakerResult) *altsAuthInfo {
return &altsAuthInfo{
p: &altspb.AltsContext{
ApplicationProtocol: result.GetApplicationProtocol(),
RecordProtocol: result.GetRecordProtocol(),
// TODO: assign security level from result.
SecurityLevel: altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
PeerServiceAccount: result.GetPeerIdentity().GetServiceAccount(),
LocalServiceAccount: result.GetLocalIdentity().GetServiceAccount(),
PeerRpcVersions: result.GetPeerRpcVersions(),
},
}
}
// AuthType identifies the context as providing ALTS authentication information.
func (s *altsAuthInfo) AuthType() string {
return "alts"
}
// ApplicationProtocol returns the context's application protocol.
func (s *altsAuthInfo) ApplicationProtocol() string {
return s.p.GetApplicationProtocol()
}
// RecordProtocol returns the context's record protocol.
func (s *altsAuthInfo) RecordProtocol() string {
return s.p.GetRecordProtocol()
}
// SecurityLevel returns the context's security level.
func (s *altsAuthInfo) SecurityLevel() altspb.SecurityLevel {
return s.p.GetSecurityLevel()
}
// PeerServiceAccount returns the context's peer service account.
func (s *altsAuthInfo) PeerServiceAccount() string {
return s.p.GetPeerServiceAccount()
}
// LocalServiceAccount returns the context's local service account.
func (s *altsAuthInfo) LocalServiceAccount() string {
return s.p.GetLocalServiceAccount()
}
// PeerRPCVersions returns the context's peer RPC versions.
func (s *altsAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions {
return s.p.GetPeerRpcVersions()
}

View File

@@ -0,0 +1,134 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package authinfo
import (
"reflect"
"testing"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
const (
testAppProtocol = "my_app"
testRecordProtocol = "very_secure_protocol"
testPeerAccount = "peer_service_account"
testLocalAccount = "local_service_account"
testPeerHostname = "peer_hostname"
testLocalHostname = "local_hostname"
)
func TestALTSAuthInfo(t *testing.T) {
for _, tc := range []struct {
result *altspb.HandshakerResult
outAppProtocol string
outRecordProtocol string
outSecurityLevel altspb.SecurityLevel
outPeerAccount string
outLocalAccount string
outPeerRPCVersions *altspb.RpcProtocolVersions
}{
{
&altspb.HandshakerResult{
ApplicationProtocol: testAppProtocol,
RecordProtocol: testRecordProtocol,
PeerIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: testPeerAccount,
},
},
LocalIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: testLocalAccount,
},
},
},
testAppProtocol,
testRecordProtocol,
altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
testPeerAccount,
testLocalAccount,
nil,
},
{
&altspb.HandshakerResult{
ApplicationProtocol: testAppProtocol,
RecordProtocol: testRecordProtocol,
PeerIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: testPeerHostname,
},
},
LocalIdentity: &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: testLocalHostname,
},
},
PeerRpcVersions: &altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 20,
Minor: 21,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 10,
Minor: 11,
},
},
},
testAppProtocol,
testRecordProtocol,
altspb.SecurityLevel_INTEGRITY_AND_PRIVACY,
"",
"",
&altspb.RpcProtocolVersions{
MaxRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 20,
Minor: 21,
},
MinRpcVersion: &altspb.RpcProtocolVersions_Version{
Major: 10,
Minor: 11,
},
},
},
} {
authInfo := newAuthInfo(tc.result)
if got, want := authInfo.AuthType(), "alts"; got != want {
t.Errorf("authInfo.AuthType()=%v, want %v", got, want)
}
if got, want := authInfo.ApplicationProtocol(), tc.outAppProtocol; got != want {
t.Errorf("authInfo.ApplicationProtocol()=%v, want %v", got, want)
}
if got, want := authInfo.RecordProtocol(), tc.outRecordProtocol; got != want {
t.Errorf("authInfo.RecordProtocol()=%v, want %v", got, want)
}
if got, want := authInfo.SecurityLevel(), tc.outSecurityLevel; got != want {
t.Errorf("authInfo.SecurityLevel()=%v, want %v", got, want)
}
if got, want := authInfo.PeerServiceAccount(), tc.outPeerAccount; got != want {
t.Errorf("authInfo.PeerServiceAccount()=%v, want %v", got, want)
}
if got, want := authInfo.LocalServiceAccount(), tc.outLocalAccount; got != want {
t.Errorf("authInfo.LocalServiceAccount()=%v, want %v", got, want)
}
if got, want := authInfo.PeerRPCVersions(), tc.outPeerRPCVersions; !reflect.DeepEqual(got, want) {
t.Errorf("authinfo.PeerRpcVersions()=%v, want %v", got, want)
}
}
}

View File

@@ -0,0 +1,69 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
//go:generate ./regenerate.sh
// Package internal contains common core functionality for ALTS.
package internal
import (
"net"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
)
const (
// ClientSide identifies the client in this communication.
ClientSide Side = iota
// ServerSide identifies the server in this communication.
ServerSide
)
// PeerNotRespondingError is returned when a peer server is not responding
// after a channel has been established. It is treated as a temporary connection
// error and re-connection to the server should be attempted.
var PeerNotRespondingError = &peerNotRespondingError{}
// Side identifies the party's role: client or server.
type Side int
type peerNotRespondingError struct{}
// Return an error message for the purpose of logging.
func (e *peerNotRespondingError) Error() string {
return "peer server is not responding and re-connection should be attempted."
}
// Temporary indicates if this connection error is temporary or fatal.
func (e *peerNotRespondingError) Temporary() bool {
return true
}
// Handshaker defines a ALTS handshaker interface.
type Handshaker interface {
// ClientHandshake starts and completes a client-side handshaking and
// returns a secure connection and corresponding auth information.
ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
// ServerHandshake starts and completes a server-side handshaking and
// returns a secure connection and corresponding auth information.
ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
// Close terminates the Handshaker. It should be called when the caller
// obtains the secure connection.
Close()
}

View File

@@ -0,0 +1,131 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"fmt"
"strconv"
)
// rekeyAEAD holds the necessary information for an AEAD based on
// AES-GCM that performs nonce-based key derivation and XORs the
// nonce with a random mask.
type rekeyAEAD struct {
kdfKey []byte
kdfCounter []byte
nonceMask []byte
nonceBuf []byte
gcmAEAD cipher.AEAD
}
// KeySizeError signals that the given key does not have the correct size.
type KeySizeError int
func (k KeySizeError) Error() string {
return "alts/conn: invalid key size " + strconv.Itoa(int(k))
}
// newRekeyAEAD creates a new instance of aes128gcm with rekeying.
// The key argument should be 44 bytes, the first 32 bytes are used as a key
// for HKDF-expand and the remainining 12 bytes are used as a random mask for
// the counter.
func newRekeyAEAD(key []byte) (*rekeyAEAD, error) {
k := len(key)
if k != kdfKeyLen+nonceLen {
return nil, KeySizeError(k)
}
return &rekeyAEAD{
kdfKey: key[:kdfKeyLen],
kdfCounter: make([]byte, kdfCounterLen),
nonceMask: key[kdfKeyLen:],
nonceBuf: make([]byte, nonceLen),
gcmAEAD: nil,
}, nil
}
// Seal rekeys if nonce[2:8] is different than in the last call, masks the nonce,
// and calls Seal for aes128gcm.
func (s *rekeyAEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
if err := s.rekeyIfRequired(nonce); err != nil {
panic(fmt.Sprintf("Rekeying failed with: %s", err.Error()))
}
maskNonce(s.nonceBuf, nonce, s.nonceMask)
return s.gcmAEAD.Seal(dst, s.nonceBuf, plaintext, additionalData)
}
// Open rekeys if nonce[2:8] is different than in the last call, masks the nonce,
// and calls Open for aes128gcm.
func (s *rekeyAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if err := s.rekeyIfRequired(nonce); err != nil {
return nil, err
}
maskNonce(s.nonceBuf, nonce, s.nonceMask)
return s.gcmAEAD.Open(dst, s.nonceBuf, ciphertext, additionalData)
}
// rekeyIfRequired creates a new aes128gcm AEAD if the existing AEAD is nil
// or cannot be used with given nonce.
func (s *rekeyAEAD) rekeyIfRequired(nonce []byte) error {
newKdfCounter := nonce[kdfCounterOffset : kdfCounterOffset+kdfCounterLen]
if s.gcmAEAD != nil && bytes.Equal(newKdfCounter, s.kdfCounter) {
return nil
}
copy(s.kdfCounter, newKdfCounter)
a, err := aes.NewCipher(hkdfExpand(s.kdfKey, s.kdfCounter))
if err != nil {
return err
}
s.gcmAEAD, err = cipher.NewGCM(a)
return err
}
// maskNonce XORs the given nonce with the mask and stores the result in dst.
func maskNonce(dst, nonce, mask []byte) {
nonce1 := binary.LittleEndian.Uint64(nonce[:sizeUint64])
nonce2 := binary.LittleEndian.Uint32(nonce[sizeUint64:])
mask1 := binary.LittleEndian.Uint64(mask[:sizeUint64])
mask2 := binary.LittleEndian.Uint32(mask[sizeUint64:])
binary.LittleEndian.PutUint64(dst[:sizeUint64], nonce1^mask1)
binary.LittleEndian.PutUint32(dst[sizeUint64:], nonce2^mask2)
}
// NonceSize returns the required nonce size.
func (s *rekeyAEAD) NonceSize() int {
return s.gcmAEAD.NonceSize()
}
// Overhead returns the ciphertext overhead.
func (s *rekeyAEAD) Overhead() int {
return s.gcmAEAD.Overhead()
}
// hkdfExpand computes the first 16 bytes of the HKDF-expand function
// defined in RFC5869.
func hkdfExpand(key, info []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(info)
mac.Write([]byte{0x01}[:])
return mac.Sum(nil)[:aeadKeyLen]
}

View File

@@ -0,0 +1,263 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"encoding/hex"
"testing"
)
// cryptoTestVector is struct for a rekey test vector
type rekeyAEADTestVector struct {
desc string
key, nonce, plaintext, aad, ciphertext []byte
}
// Test encrypt and decrypt using (adapted) test vectors for AES-GCM.
func TestAES128GCMRekeyEncrypt(t *testing.T) {
for _, test := range []rekeyAEADTestVector{
// NIST vectors from:
// http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
//
// IEEE vectors from:
// http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf
//
// Key expanded by setting
// expandedKey = (key ||
// key ^ {0x01,..,0x01} ||
// key ^ {0x02,..,0x02})[0:44].
{
desc: "Derived from NIST test vector 1",
key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"),
nonce: dehex("000000000000000000000000"),
aad: dehex(""),
plaintext: dehex(""),
ciphertext: dehex("85e873e002f6ebdc4060954eb8675508"),
},
{
desc: "Derived from NIST test vector 2",
key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"),
nonce: dehex("000000000000000000000000"),
aad: dehex(""),
plaintext: dehex("00000000000000000000000000000000"),
ciphertext: dehex("51e9a8cb23ca2512c8256afff8e72d681aca19a1148ac115e83df4888cc00d11"),
},
{
desc: "Derived from NIST test vector 3",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddecaf888"),
aad: dehex(""),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255"),
ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c0df1162129952213cee1bc6e9c8495dd705e1f3d"),
},
{
desc: "Derived from NIST test vector 4",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c4764565d077e9124001ddb27fc0848c5"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 15)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("ca7ebabefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("e650d3c0fb879327f2d03287fa93cd07342b136215adbca00c3bd5099ec41832b1d18e0423ed26bb12c6cd09debb29230a94c0cee15903656f85edb6fc509b1b28216382172ecbcc31e1e9b1"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 16)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebbbefacedbaddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("c0121e6c954d0767f96630c33450999791b2da2ad05c4190169ccad9ac86ff1c721e3d82f2ad22ab463bab4a0754b7dd68ca4de7ea2531b625eda01f89312b2ab957d5c7f8568dd95fcdcd1f"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 63)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedb2ddecaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("8af37ea5684a4d81d4fd817261fd9743099e7e6a025eaacf8e54b124fb5743149e05cb89f4a49467fe2e5e5965f29a19f99416b0016b54585d12553783ba59e9f782e82e097c336bf7989f08"),
},
{
desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 64)",
key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"),
nonce: dehex("cafebabefacedbaddfcaf888"),
aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"),
plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"),
ciphertext: dehex("fbd528448d0346bfa878634864d407a35a039de9db2f1feb8e965b3ae9356ce6289441d77f8f0df294891f37ea438b223e3bf2bdc53d4c5a74fb680bb312a8dec6f7252cbcd7f5799750ad78"),
},
{
desc: "Derived from IEEE 2.1.1 54-byte auth",
key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"),
plaintext: dehex(""),
ciphertext: dehex("3ea0b584f3c85e93f9320ea591699efb"),
},
{
desc: "Derived from IEEE 2.1.2 54-byte auth",
key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"),
plaintext: dehex(""),
ciphertext: dehex("294e028bf1fe6f14c4e8f7305c933eb5"),
},
{
desc: "Derived from IEEE 2.2.1 60-byte crypt",
key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"),
ciphertext: dehex("db3d25719c6b0a3ca6145c159d5c6ed9aff9c6e0b79f17019ea923b8665ddf52137ad611f0d1bf417a7ca85e45afe106ff9c7569d335d086ae6c03f00987ccd6"),
},
{
desc: "Derived from IEEE 2.2.2 60-byte crypt",
key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"),
nonce: dehex("12153524c0895e81b2c28465"),
aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"),
ciphertext: dehex("1641f28ec13afcc8f7903389787201051644914933e9202bb9d06aa020c2a67ef51dfe7bc00a856c55b8f8133e77f659132502bad63f5713d57d0c11e0f871ed"),
},
{
desc: "Derived from IEEE 2.3.1 60-byte auth",
key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"),
plaintext: dehex(""),
ciphertext: dehex("58837a10562b0f1f8edbe58ca55811d3"),
},
{
desc: "Derived from IEEE 2.3.2 60-byte auth",
key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"),
plaintext: dehex(""),
ciphertext: dehex("c2722ff6ca29a257718a529d1f0c6a3b"),
},
{
desc: "Derived from IEEE 2.4.1 54-byte crypt",
key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"),
ciphertext: dehex("fd96b715b93a13346af51e8acdf792cdc7b2686f8574c70e6b0cbf16291ded427ad73fec48cd298e0528a1f4c644a949fc31dc9279706ddba33f"),
},
{
desc: "Derived from IEEE 2.4.2 54-byte crypt",
key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"),
nonce: dehex("f0761e8dcd3d000176d457ed"),
aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"),
ciphertext: dehex("b68f6300c2e9ae833bdc070e24021a3477118e78ccf84e11a485d861476c300f175353d5cdf92008a4f878e6cc3577768085c50a0e98fda6cbb8"),
},
{
desc: "Derived from IEEE 2.5.1 65-byte auth",
key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"),
plaintext: dehex(""),
ciphertext: dehex("cca20eecda6283f09bb3543dd99edb9b"),
},
{
desc: "Derived from IEEE 2.5.2 65-byte auth",
key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"),
plaintext: dehex(""),
ciphertext: dehex("b232cc1da5117bf15003734fa599d271"),
},
{
desc: "Derived from IEEE 2.6.1 61-byte crypt",
key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"),
ciphertext: dehex("ff1910d35ad7e5657890c7c560146fd038707f204b66edbc3d161f8ace244b985921023c436e3a1c3532ecd5d09a056d70be583f0d10829d9387d07d33d872e490"),
},
{
desc: "Derived from IEEE 2.6.2 61-byte crypt",
key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"),
nonce: dehex("7cfde9f9e33724c68932d612"),
aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"),
ciphertext: dehex("0db4cf956b5f97eca4eab82a6955307f9ae02a32dd7d93f83d66ad04e1cfdc5182ad12abdea5bbb619a1bd5fb9a573590fba908e9c7a46c1f7ba0905d1b55ffda4"),
},
{
desc: "Derived from IEEE 2.7.1 79-byte crypt",
key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"),
plaintext: dehex(""),
ciphertext: dehex("813f0e630f96fb2d030f58d83f5cdfd0"),
},
{
desc: "Derived from IEEE 2.7.2 79-byte crypt",
key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"),
plaintext: dehex(""),
ciphertext: dehex("77e5a44c21eb07188aacbd74d1980e97"),
},
{
desc: "Derived from IEEE 2.8.1 61-byte crypt",
key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"),
ciphertext: dehex("958ec3f6d60afeda99efd888f175e5fcd4c87b9bcc5c2f5426253a8b506296c8c43309ab2adb5939462541d95e80811e04e706b1498f2c407c7fb234f8cc01a647550ee6b557b35a7e3945381821f4"),
},
{
desc: "Derived from IEEE 2.8.2 61-byte crypt",
key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"),
nonce: dehex("7ae8e2ca4ec500012e58495c"),
aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"),
plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"),
ciphertext: dehex("b44d072011cd36d272a9b7a98db9aa90cbc5c67b93ddce67c854503214e2e896ec7e9db649ed4bcf6f850aac0223d0cf92c83db80795c3a17ecc1248bb00591712b1ae71e268164196252162810b00"),
}} {
aead, err := newRekeyAEAD(test.key)
if err != nil {
t.Fatal("unexpected failure in newRekeyAEAD: ", err.Error())
}
if got := aead.Seal(nil, test.nonce, test.plaintext, test.aad); !bytes.Equal(got, test.ciphertext) {
t.Errorf("Unexpected ciphertext for test vector '%s':\nciphertext=%s\nwant= %s",
test.desc, hex.EncodeToString(got), hex.EncodeToString(test.ciphertext))
}
if got, err := aead.Open(nil, test.nonce, test.ciphertext, test.aad); err != nil || !bytes.Equal(got, test.plaintext) {
t.Errorf("Unexpected plaintext for test vector '%s':\nplaintext=%s (err=%v)\nwant= %s",
test.desc, hex.EncodeToString(got), err, hex.EncodeToString(test.plaintext))
}
}
}
func dehex(s string) []byte {
if len(s) == 0 {
return make([]byte, 0)
}
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

View File

@@ -0,0 +1,105 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"crypto/aes"
"crypto/cipher"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
// Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in
// each direction).
overflowLenAES128GCM = 5
)
// aes128gcm is the struct that holds necessary information for ALTS record.
// The counter value is NOT included in the payload during the encryption and
// decryption operations.
type aes128gcm struct {
// inCounter is used in ALTS record to check that incoming counters are
// as expected, since ALTS record guarantees that messages are unwrapped
// in the same order that the peer wrapped them.
inCounter Counter
outCounter Counter
aead cipher.AEAD
}
// NewAES128GCM creates an instance that uses aes128gcm for ALTS record.
func NewAES128GCM(side core.Side, key []byte) (ALTSRecordCrypto, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
a, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
return &aes128gcm{
inCounter: NewInCounter(side, overflowLenAES128GCM),
outCounter: NewOutCounter(side, overflowLenAES128GCM),
aead: a,
}, nil
}
// Encrypt is the encryption function. dst can contain bytes at the beginning of
// the ciphertext that will not be encrypted but will be authenticated. If dst
// has enough capacity to hold these bytes, the ciphertext and the tag, no
// allocation and copy operations will be performed. dst and plaintext do not
// overlap.
func (s *aes128gcm) Encrypt(dst, plaintext []byte) ([]byte, error) {
// If we need to allocate an output buffer, we want to include space for
// GCM tag to avoid forcing ALTS record to reallocate as well.
dlen := len(dst)
dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize)
seq, err := s.outCounter.Value()
if err != nil {
return nil, err
}
data := out[:len(plaintext)]
copy(data, plaintext) // data may alias plaintext
// Seal appends the ciphertext and the tag to its first argument and
// returns the updated slice. However, SliceForAppend above ensures that
// dst has enough capacity to avoid a reallocation and copy due to the
// append.
dst = s.aead.Seal(dst[:dlen], seq, data, nil)
s.outCounter.Inc()
return dst, nil
}
func (s *aes128gcm) EncryptionOverhead() int {
return GcmTagSize
}
func (s *aes128gcm) Decrypt(dst, ciphertext []byte) ([]byte, error) {
seq, err := s.inCounter.Value()
if err != nil {
return nil, err
}
// If dst is equal to ciphertext[:0], ciphertext storage is reused.
plaintext, err := s.aead.Open(dst, seq, ciphertext, nil)
if err != nil {
return nil, ErrAuth
}
s.inCounter.Inc()
return plaintext, nil
}

View File

@@ -0,0 +1,223 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
// cryptoTestVector is struct for a GCM test vector
type cryptoTestVector struct {
key, counter, plaintext, ciphertext, tag []byte
allocateDst bool
}
// getGCMCryptoPair outputs a client/server pair on aes128gcm.
func getGCMCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) {
client, err := NewAES128GCM(core.ClientSide, key)
if err != nil {
t.Fatalf("NewAES128GCM(ClientSide, key) = %v", err)
}
server, err := NewAES128GCM(core.ServerSide, key)
if err != nil {
t.Fatalf("NewAES128GCM(ServerSide, key) = %v", err)
}
// set counter if provided.
if counter != nil {
if CounterSide(counter) == core.ClientSide {
client.(*aes128gcm).outCounter = CounterFromValue(counter, overflowLenAES128GCM)
server.(*aes128gcm).inCounter = CounterFromValue(counter, overflowLenAES128GCM)
} else {
server.(*aes128gcm).outCounter = CounterFromValue(counter, overflowLenAES128GCM)
client.(*aes128gcm).inCounter = CounterFromValue(counter, overflowLenAES128GCM)
}
}
return client, server
}
func testGCMEncryptionDecryption(sender ALTSRecordCrypto, receiver ALTSRecordCrypto, test *cryptoTestVector, withCounter bool, t *testing.T) {
// Ciphertext is: counter + encrypted text + tag.
ciphertext := []byte(nil)
if withCounter {
ciphertext = append(ciphertext, test.counter...)
}
ciphertext = append(ciphertext, test.ciphertext...)
ciphertext = append(ciphertext, test.tag...)
// Decrypt.
if got, err := receiver.Decrypt(nil, ciphertext); err != nil || !bytes.Equal(got, test.plaintext) {
t.Errorf("key=%v\ncounter=%v\ntag=%v\nciphertext=%v\nDecrypt = %v, %v\nwant: %v",
test.key, test.counter, test.tag, test.ciphertext, got, err, test.plaintext)
}
// Encrypt.
var dst []byte
if test.allocateDst {
dst = make([]byte, len(test.plaintext)+sender.EncryptionOverhead())
}
if got, err := sender.Encrypt(dst[:0], test.plaintext); err != nil || !bytes.Equal(got, ciphertext) {
t.Errorf("key=%v\ncounter=%v\nplaintext=%v\nEncrypt = %v, %v\nwant: %v",
test.key, test.counter, test.plaintext, got, err, ciphertext)
}
}
// Test encrypt and decrypt using test vectors for aes128gcm.
func TestAES128GCMEncrypt(t *testing.T) {
for _, test := range []cryptoTestVector{
{
key: dehex("11754cd72aec309bf52f7687212e8957"),
counter: dehex("3c819d9a9bed087615030b65"),
plaintext: nil,
ciphertext: nil,
tag: dehex("250327c674aaf477aef2675748cf6971"),
allocateDst: false,
},
{
key: dehex("ca47248ac0b6f8372a97ac43508308ed"),
counter: dehex("ffd2b598feabc9019262d2be"),
plaintext: nil,
ciphertext: nil,
tag: dehex("60d20404af527d248d893ae495707d1a"),
allocateDst: false,
},
{
key: dehex("7fddb57453c241d03efbed3ac44e371c"),
counter: dehex("ee283a3fc75575e33efd4887"),
plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"),
ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"),
tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"),
allocateDst: false,
},
{
key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"),
counter: dehex("54cc7dc2c37ec006bcc6d1da"),
plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"),
ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"),
tag: dehex("2b4401346697138c7a4891ee59867d0c"),
allocateDst: false,
},
{
key: dehex("11754cd72aec309bf52f7687212e8957"),
counter: dehex("3c819d9a9bed087615030b65"),
plaintext: nil,
ciphertext: nil,
tag: dehex("250327c674aaf477aef2675748cf6971"),
allocateDst: true,
},
{
key: dehex("ca47248ac0b6f8372a97ac43508308ed"),
counter: dehex("ffd2b598feabc9019262d2be"),
plaintext: nil,
ciphertext: nil,
tag: dehex("60d20404af527d248d893ae495707d1a"),
allocateDst: true,
},
{
key: dehex("7fddb57453c241d03efbed3ac44e371c"),
counter: dehex("ee283a3fc75575e33efd4887"),
plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"),
ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"),
tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"),
allocateDst: true,
},
{
key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"),
counter: dehex("54cc7dc2c37ec006bcc6d1da"),
plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"),
ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"),
tag: dehex("2b4401346697138c7a4891ee59867d0c"),
allocateDst: true,
},
} {
// Test encryption and decryption for aes128gcm.
client, server := getGCMCryptoPair(test.key, test.counter, t)
if CounterSide(test.counter) == core.ClientSide {
testGCMEncryptionDecryption(client, server, &test, false, t)
} else {
testGCMEncryptionDecryption(server, client, &test, false, t)
}
}
}
func testGCMEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) {
// Encrypt.
const plaintext = "This is plaintext."
var err error
buf := []byte(plaintext)
buf, err = client.Encrypt(buf[:0], buf)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext))
}
// Encrypt a second message.
const plaintext2 = "This is a second plaintext."
buf2 := []byte(plaintext2)
buf2, err = client.Encrypt(buf2[:0], buf2)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext2))
}
// Decryption fails: cannot decrypt second message before first.
if got, err := server.Decrypt(nil, buf2); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext2), "\n",
" Ciphertext:", buf2, "\n",
" Decrypted plaintext:", got)
}
// Decryption fails: wrong counter space.
if got, err := client.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
// Decrypt first message.
ciphertext := append([]byte(nil), buf...)
buf, err = server.Decrypt(buf[:0], buf)
if err != nil || string(buf) != plaintext {
t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", ciphertext, "\n",
" Decryption error:", err, "\n",
" Decrypted plaintext:", buf)
}
// Decryption fails: replay attack.
if got, err := server.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
}
// Test encrypt and decrypt on roundtrip messages for aes128gcm.
func TestAES128GCMEncryptRoundtrip(t *testing.T) {
// Test for aes128gcm.
key := make([]byte, 16)
client, server := getGCMCryptoPair(key, nil, t)
testGCMEncryptRoundtrip(client, server, t)
}

View File

@@ -0,0 +1,116 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"crypto/cipher"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
// Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in
// each direction).
overflowLenAES128GCMRekey = 8
nonceLen = 12
aeadKeyLen = 16
kdfKeyLen = 32
kdfCounterOffset = 2
kdfCounterLen = 6
sizeUint64 = 8
)
// aes128gcmRekey is the struct that holds necessary information for ALTS record.
// The counter value is NOT included in the payload during the encryption and
// decryption operations.
type aes128gcmRekey struct {
// inCounter is used in ALTS record to check that incoming counters are
// as expected, since ALTS record guarantees that messages are unwrapped
// in the same order that the peer wrapped them.
inCounter Counter
outCounter Counter
inAEAD cipher.AEAD
outAEAD cipher.AEAD
}
// NewAES128GCMRekey creates an instance that uses aes128gcm with rekeying
// for ALTS record. The key argument should be 44 bytes, the first 32 bytes
// are used as a key for HKDF-expand and the remainining 12 bytes are used
// as a random mask for the counter.
func NewAES128GCMRekey(side core.Side, key []byte) (ALTSRecordCrypto, error) {
inCounter := NewInCounter(side, overflowLenAES128GCMRekey)
outCounter := NewOutCounter(side, overflowLenAES128GCMRekey)
inAEAD, err := newRekeyAEAD(key)
if err != nil {
return nil, err
}
outAEAD, err := newRekeyAEAD(key)
if err != nil {
return nil, err
}
return &aes128gcmRekey{
inCounter,
outCounter,
inAEAD,
outAEAD,
}, nil
}
// Encrypt is the encryption function. dst can contain bytes at the beginning of
// the ciphertext that will not be encrypted but will be authenticated. If dst
// has enough capacity to hold these bytes, the ciphertext and the tag, no
// allocation and copy operations will be performed. dst and plaintext do not
// overlap.
func (s *aes128gcmRekey) Encrypt(dst, plaintext []byte) ([]byte, error) {
// If we need to allocate an output buffer, we want to include space for
// GCM tag to avoid forcing ALTS record to reallocate as well.
dlen := len(dst)
dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize)
seq, err := s.outCounter.Value()
if err != nil {
return nil, err
}
data := out[:len(plaintext)]
copy(data, plaintext) // data may alias plaintext
// Seal appends the ciphertext and the tag to its first argument and
// returns the updated slice. However, SliceForAppend above ensures that
// dst has enough capacity to avoid a reallocation and copy due to the
// append.
dst = s.outAEAD.Seal(dst[:dlen], seq, data, nil)
s.outCounter.Inc()
return dst, nil
}
func (s *aes128gcmRekey) EncryptionOverhead() int {
return GcmTagSize
}
func (s *aes128gcmRekey) Decrypt(dst, ciphertext []byte) ([]byte, error) {
seq, err := s.inCounter.Value()
if err != nil {
return nil, err
}
plaintext, err := s.inAEAD.Open(dst, seq, ciphertext, nil)
if err != nil {
return nil, ErrAuth
}
s.inCounter.Inc()
return plaintext, nil
}

View File

@@ -0,0 +1,117 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
// cryptoTestVector is struct for a rekey test vector
type rekeyTestVector struct {
key, nonce, plaintext, ciphertext []byte
}
// getGCMCryptoPair outputs a client/server pair on aes128gcmRekey.
func getRekeyCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) {
client, err := NewAES128GCMRekey(core.ClientSide, key)
if err != nil {
t.Fatalf("NewAES128GCMRekey(ClientSide, key) = %v", err)
}
server, err := NewAES128GCMRekey(core.ServerSide, key)
if err != nil {
t.Fatalf("NewAES128GCMRekey(ServerSide, key) = %v", err)
}
// set counter if provided.
if counter != nil {
if CounterSide(counter) == core.ClientSide {
client.(*aes128gcmRekey).outCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
server.(*aes128gcmRekey).inCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
} else {
server.(*aes128gcmRekey).outCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
client.(*aes128gcmRekey).inCounter = CounterFromValue(counter, overflowLenAES128GCMRekey)
}
}
return client, server
}
func testRekeyEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) {
// Encrypt.
const plaintext = "This is plaintext."
var err error
buf := []byte(plaintext)
buf, err = client.Encrypt(buf[:0], buf)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext))
}
// Encrypt a second message.
const plaintext2 = "This is a second plaintext."
buf2 := []byte(plaintext2)
buf2, err = client.Encrypt(buf2[:0], buf2)
if err != nil {
t.Fatal("Encrypting with client-side context: unexpected error", err, "\n",
"Plaintext:", []byte(plaintext2))
}
// Decryption fails: cannot decrypt second message before first.
if got, err := server.Decrypt(nil, buf2); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext2), "\n",
" Ciphertext:", buf2, "\n",
" Decrypted plaintext:", got)
}
// Decryption fails: wrong counter space.
if got, err := client.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
// Decrypt first message.
ciphertext := append([]byte(nil), buf...)
buf, err = server.Decrypt(buf[:0], buf)
if err != nil || string(buf) != plaintext {
t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", ciphertext, "\n",
" Decryption error:", err, "\n",
" Decrypted plaintext:", buf)
}
// Decryption fails: replay attack.
if got, err := server.Decrypt(nil, buf); err == nil {
t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n",
" Original plaintext:", []byte(plaintext), "\n",
" Ciphertext:", buf, "\n",
" Decrypted plaintext:", got)
}
}
// Test encrypt and decrypt on roundtrip messages for aes128gcmRekey.
func TestAES128GCMRekeyEncryptRoundtrip(t *testing.T) {
// Test for aes128gcmRekey.
key := make([]byte, 44)
client, server := getRekeyCryptoPair(key, nil, t)
testRekeyEncryptRoundtrip(client, server, t)
}

View File

@@ -0,0 +1,70 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"encoding/binary"
"errors"
"fmt"
)
const (
// GcmTagSize is the GCM tag size is the difference in length between
// plaintext and ciphertext. From crypto/cipher/gcm.go in Go crypto
// library.
GcmTagSize = 16
)
// ErrAuth occurs on authentication failure.
var ErrAuth = errors.New("message authentication failed")
// SliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func SliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return head, tail
}
// ParseFramedMsg parse the provided buffer and returns a frame of the format
// msgLength+msg and any remaining bytes in that buffer.
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
if len(b) < MsgLenFieldSize {
return nil, b, nil
}
msgLenField := b[:MsgLenFieldSize]
length := binary.LittleEndian.Uint32(msgLenField)
if length > maxLen {
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
if len(b) < int(length)+4 { // account for the first 4 msg length bytes.
// Frame is not complete yet.
return nil, b, nil
}
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
}

View File

@@ -0,0 +1,62 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"errors"
)
const counterLen = 12
var (
errInvalidCounter = errors.New("invalid counter")
)
// Counter is a 96-bit, little-endian counter.
type Counter struct {
value [counterLen]byte
invalid bool
overflowLen int
}
// Value returns the current value of the counter as a byte slice.
func (c *Counter) Value() ([]byte, error) {
if c.invalid {
return nil, errInvalidCounter
}
return c.value[:], nil
}
// Inc increments the counter and checks for overflow.
func (c *Counter) Inc() {
// If the counter is already invalid, there is no need to increase it.
if c.invalid {
return
}
i := 0
for ; i < c.overflowLen; i++ {
c.value[i]++
if c.value[i] != 0 {
break
}
}
if i == c.overflowLen {
c.invalid = true
}
}

View File

@@ -0,0 +1,141 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
const (
testOverflowLen = 5
)
func TestCounterSides(t *testing.T) {
for _, side := range []core.Side{core.ClientSide, core.ServerSide} {
outCounter := NewOutCounter(side, testOverflowLen)
inCounter := NewInCounter(side, testOverflowLen)
for i := 0; i < 1024; i++ {
value, _ := outCounter.Value()
if g, w := CounterSide(value), side; g != w {
t.Errorf("after %d iterations, CounterSide(outCounter.Value()) = %v, want %v", i, g, w)
break
}
value, _ = inCounter.Value()
if g, w := CounterSide(value), side; g == w {
t.Errorf("after %d iterations, CounterSide(inCounter.Value()) = %v, want %v", i, g, w)
break
}
outCounter.Inc()
inCounter.Inc()
}
}
}
func TestCounterInc(t *testing.T) {
for _, test := range []struct {
counter []byte
want []byte
}{
{
counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80},
want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80},
},
{
counter: []byte{0xff, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0x42, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: []byte{0x43, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
{
counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
},
} {
c := CounterFromValue(test.counter, overflowLenAES128GCM)
c.Inc()
value, _ := c.Value()
if g, w := value, test.want; !bytes.Equal(g, w) || c.invalid {
t.Errorf("counter(%v).Inc() =\n%v, want\n%v", test.counter, g, w)
}
}
}
func TestRolloverCounter(t *testing.T) {
for _, test := range []struct {
desc string
value []byte
overflowLen int
}{
{
desc: "testing overflow without rekeying 1",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80},
overflowLen: 5,
},
{
desc: "testing overflow without rekeying 2",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
overflowLen: 5,
},
{
desc: "testing overflow for rekeying mode 1",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x80},
overflowLen: 8,
},
{
desc: "testing overflow for rekeying mode 2",
value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00},
overflowLen: 8,
},
} {
c := CounterFromValue(test.value, overflowLenAES128GCM)
// First Inc() + Value() should work.
c.Inc()
_, err := c.Value()
if err != nil {
t.Errorf("%v: first Inc() + Value() unexpectedly failed: %v, want <nil> error", test.desc, err)
}
// Second Inc() + Value() should fail.
c.Inc()
_, err = c.Value()
if err != errInvalidCounter {
t.Errorf("%v: second Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter)
}
// Third Inc() + Value() should also fail because the counter is
// already in an invalid state.
c.Inc()
_, err = c.Value()
if err != errInvalidCounter {
t.Errorf("%v: Third Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter)
}
}
}

View File

@@ -0,0 +1,271 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package conn contains an implementation of a secure channel created by gRPC
// handshakers.
package conn
import (
"encoding/binary"
"fmt"
"math"
"net"
core "google.golang.org/grpc/credentials/alts/internal"
)
// ALTSRecordCrypto is the interface for gRPC ALTS record protocol.
type ALTSRecordCrypto interface {
// Encrypt encrypts the plaintext and computes the tag (if any) of dst
// and plaintext, dst and plaintext do not overlap.
Encrypt(dst, plaintext []byte) ([]byte, error)
// EncryptionOverhead returns the tag size (if any) in bytes.
EncryptionOverhead() int
// Decrypt decrypts ciphertext and verify the tag (if any). dst and
// ciphertext may alias exactly or not at all. To reuse ciphertext's
// storage for the decrypted output, use ciphertext[:0] as dst.
Decrypt(dst, ciphertext []byte) ([]byte, error)
}
// ALTSRecordFunc is a function type for factory functions that create
// ALTSRecordCrypto instances.
type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error)
const (
// MsgLenFieldSize is the byte size of the frame length field of a
// framed message.
MsgLenFieldSize = 4
// The byte size of the message type field of a framed message.
msgTypeFieldSize = 4
// The bytes size limit for a ALTS record message.
altsRecordLengthLimit = 1024 * 1024 // 1 MiB
// The default bytes size of a ALTS record message.
altsRecordDefaultLength = 4 * 1024 // 4KiB
// Message type value included in ALTS record framing.
altsRecordMsgType = uint32(0x06)
// The initial write buffer size.
altsWriteBufferInitialSize = 32 * 1024 // 32KiB
// The maximum write buffer size. This *must* be multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
)
var (
protocols = make(map[string]ALTSRecordFunc)
)
// RegisterProtocol register a ALTS record encryption protocol.
func RegisterProtocol(protocol string, f ALTSRecordFunc) error {
if _, ok := protocols[protocol]; ok {
return fmt.Errorf("protocol %v is already registered", protocol)
}
protocols[protocol] = f
return nil
}
// conn represents a secured connection. It implements the net.Conn interface.
type conn struct {
net.Conn
crypto ALTSRecordCrypto
// buf holds data that has been read from the connection and decrypted,
// but has not yet been returned by Read.
buf []byte
payloadLengthLimit int
// protected holds data read from the network but have not yet been
// decrypted. This data might not compose a complete frame.
protected []byte
// writeBuf is a buffer used to contain encrypted frames before being
// written to the network.
writeBuf []byte
// nextFrame stores the next frame (in protected buffer) info.
nextFrame []byte
// overhead is the calculated overhead of each frame.
overhead int
}
// NewConn creates a new secure channel instance given the other party role and
// handshaking result.
func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) {
newCrypto := protocols[recordProtocol]
if newCrypto == nil {
return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
}
crypto, err := newCrypto(side, key)
if err != nil {
return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
protected = make([]byte, 0, 2*altsRecordDefaultLength-1)
}
altsConn := &conn{
Conn: c,
crypto: crypto,
payloadLengthLimit: payloadLengthLimit,
protected: protected,
writeBuf: make([]byte, altsWriteBufferInitialSize),
nextFrame: protected,
overhead: overhead,
}
return altsConn, nil
}
// Read reads and decrypts a frame from the underlying connection, and copies the
// decrypted payload into b. If the size of the payload is greater than len(b),
// Read retains the remaining bytes in an internal buffer, and subsequent calls
// to Read will read from this buffer until it is exhausted.
func (p *conn) Read(b []byte) (n int, err error) {
if len(p.buf) == 0 {
var framedMsg []byte
framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
if err != nil {
return n, err
}
// Check whether the next frame to be decrypted has been
// completely received yet.
if len(framedMsg) == 0 {
copy(p.protected, p.nextFrame)
p.protected = p.protected[:len(p.nextFrame)]
// Always copy next incomplete frame to the beginning of
// the protected buffer and reset nextFrame to it.
p.nextFrame = p.protected
}
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if len(p.protected) == cap(p.protected) {
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
copy(tmp, p.protected)
p.protected = tmp
}
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
if err != nil {
return 0, err
}
p.protected = p.protected[:len(p.protected)+n]
framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit)
if err != nil {
return 0, err
}
}
// Now we have a complete frame, decrypted it.
msg := framedMsg[MsgLenFieldSize:]
msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize])
if msgType&0xff != altsRecordMsgType {
return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v",
msgType, altsRecordMsgType)
}
ciphertext := msg[msgTypeFieldSize:]
// Decrypt requires that if the dst and ciphertext alias, they
// must alias exactly. Code here used to use msg[:0], but msg
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
// ciphertext, so they alias inexactly. Using ciphertext[:0]
// arranges the appropriate aliasing without needing to copy
// ciphertext or use a separate destination buffer. For more info
// check: https://golang.org/pkg/crypto/cipher/#AEAD.
p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext)
if err != nil {
return 0, err
}
}
n = copy(b, p.buf)
p.buf = p.buf[n:]
return n, nil
}
// Write encrypts, frames, and writes bytes from b to the underlying connection.
func (p *conn) Write(b []byte) (n int, err error) {
n = len(b)
// Calculate the output buffer size with framing and encryption overhead.
numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit)))
size := len(b) + numOfFrames*p.overhead
// If writeBuf is too small, increase its size up to the maximum size.
partialBSize := len(b)
if size > altsWriteBufferMaxSize {
size = altsWriteBufferMaxSize
const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength
partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit
}
if len(p.writeBuf) < size {
p.writeBuf = make([]byte, size)
}
for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize {
partialBEnd := partialBStart + partialBSize
if partialBEnd > len(b) {
partialBEnd = len(b)
}
partialB := b[partialBStart:partialBEnd]
writeBufIndex := 0
for len(partialB) > 0 {
payloadLen := len(partialB)
if payloadLen > p.payloadLengthLimit {
payloadLen = p.payloadLengthLimit
}
buf := partialB[:payloadLen]
partialB = partialB[payloadLen:]
// Write buffer contains: length, type, payload, and tag
// if any.
// 1. Fill in type field.
msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:]
binary.LittleEndian.PutUint32(msg, altsRecordMsgType)
// 2. Encrypt the payload and create a tag if any.
msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf)
if err != nil {
return n, err
}
// 3. Fill in the size field.
binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg)))
// 4. Increase writeBufIndex.
writeBufIndex += len(buf) + p.overhead
}
nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex])
if err != nil {
// We need to calculate the actual data size that was
// written. This means we need to remove header,
// encryption overheads, and any partially-written
// frame data.
numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength)))
return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err
}
}
return n, nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,274 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"math"
"net"
"reflect"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
)
var (
nextProtocols = []string{"ALTSRP_GCM_AES128"}
altsRecordFuncs = map[string]ALTSRecordFunc{
// ALTS handshaker protocols.
"ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
return NewAES128GCM(s, keyData)
},
}
)
func init() {
for protocol, f := range altsRecordFuncs {
if err := RegisterProtocol(protocol, f); err != nil {
panic(err)
}
}
}
// testConn mimics a net.Conn to the peer.
type testConn struct {
net.Conn
in *bytes.Buffer
out *bytes.Buffer
}
func (c *testConn) Read(b []byte) (n int, err error) {
return c.in.Read(b)
}
func (c *testConn) Write(b []byte) (n int, err error) {
return c.out.Write(b)
}
func (c *testConn) Close() error {
return nil
}
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn {
key := []byte{
// 16 arbitrary bytes.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
tc := testConn{
in: in,
out: out,
}
c, err := NewConn(&tc, side, np, key, nil)
if err != nil {
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
}
return c.(*conn)
}
func newConnPair(np string) (client, server *conn) {
clientBuf := new(bytes.Buffer)
serverBuf := new(bytes.Buffer)
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np)
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np)
return clientConn, serverConn
}
func testPingPong(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
clientMsg := []byte("Client Message")
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
}
rcvClientMsg := make([]byte, len(clientMsg))
if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
}
if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
}
serverMsg := []byte("Server Message")
if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil {
t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg))
}
rcvServerMsg := make([]byte, len(serverMsg))
if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil {
t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg))
}
if !reflect.DeepEqual(serverMsg, rcvServerMsg) {
t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg)
}
}
func TestPingPong(t *testing.T) {
for _, np := range nextProtocols {
testPingPong(t, np)
}
}
func testSmallReadBuffer(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
msg := []byte("Very Important Message")
if n, err := clientConn.Write(msg); err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
n := 2 // Arbitrary index to break rcvMsg in two.
rcvMsg1 := rcvMsg[:n]
rcvMsg2 := rcvMsg[n:]
if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1))
}
if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg)
}
}
func TestSmallReadBuffer(t *testing.T) {
for _, np := range nextProtocols {
testSmallReadBuffer(t, np)
}
}
func testLargeMsg(t *testing.T, np string) {
clientConn, serverConn := newConnPair(np)
// msgLen is such that the length in the framing is larger than the
// default size of one frame.
msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
msg := make([]byte, msgLen)
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
}
}
func TestLargeMsg(t *testing.T) {
for _, np := range nextProtocols {
testLargeMsg(t, np)
}
}
func testIncorrectMsgType(t *testing.T, np string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize)
wrongMsgType := uint32(0x22)
binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
in := bytes.NewBuffer(framedMsg)
c := newTestALTSRecordConn(in, nil, core.ClientSide, np)
b := make([]byte, 1)
if n, err := c.Read(b); n != 0 || err == nil {
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
}
}
func TestIncorrectMsgType(t *testing.T) {
for _, np := range nextProtocols {
testIncorrectMsgType(t, np)
}
}
func testFrameTooLarge(t *testing.T, np string) {
buf := new(bytes.Buffer)
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np)
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np)
// payloadLen is such that the length in the framing is larger than
// allowed in one frame.
payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
payload := make([]byte, payloadLen)
c, err := clientConn.crypto.Encrypt(nil, payload)
if err != nil {
t.Fatalf(fmt.Sprintf("Error encrypting message: %v", err))
}
msgLen := msgTypeFieldSize + len(c)
framedMsg := make([]byte, MsgLenFieldSize+msgLen)
binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c)))
msg := framedMsg[MsgLenFieldSize:]
binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType)
copy(msg[msgTypeFieldSize:], c)
if _, err = buf.Write(framedMsg); err != nil {
t.Fatal(fmt.Sprintf("Unexpected error writing to buffer: %v", err))
}
b := make([]byte, 1)
if n, err := serverConn.Read(b); n != 0 || err == nil {
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit))
}
}
func TestFrameTooLarge(t *testing.T) {
for _, np := range nextProtocols {
testFrameTooLarge(t, np)
}
}
func testWriteLargeData(t *testing.T, np string) {
// Test sending and receiving messages larger than the maximum write
// buffer size.
clientConn, serverConn := newConnPair(np)
// Message size is intentionally chosen to not be multiple of
// payloadLengthLimtit.
msgSize := altsWriteBufferMaxSize + (100 * 1024)
clientMsg := make([]byte, msgSize)
for i := 0; i < msgSize; i++ {
clientMsg[i] = 0xAA
}
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
}
// We need to keep reading until the entire message is received. The
// reason we set all bytes of the message to a value other than zero is
// to avoid ambiguous zero-init value of rcvClientMsg buffer and the
// actual received data.
rcvClientMsg := make([]byte, 0, msgSize)
numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit)))
for i := 0; i < numberOfExpectedFrames; i++ {
expectedRcvSize := serverConn.payloadLengthLimit
if i == numberOfExpectedFrames-1 {
// Last frame might be smaller.
expectedRcvSize = msgSize % serverConn.payloadLengthLimit
}
tmpBuf := make([]byte, expectedRcvSize)
if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil {
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf))
}
rcvClientMsg = append(rcvClientMsg, tmpBuf...)
}
if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
}
}
func TestWriteLargeData(t *testing.T) {
for _, np := range nextProtocols {
testWriteLargeData(t, np)
}
}

View File

@@ -0,0 +1,63 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import core "google.golang.org/grpc/credentials/alts/internal"
// NewOutCounter returns an outgoing counter initialized to the starting sequence
// number for the client/server side of a connection.
func NewOutCounter(s core.Side, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
if s == core.ServerSide {
// Server counters in ALTS record have the little-endian high bit
// set.
c.value[counterLen-1] = 0x80
}
return
}
// NewInCounter returns an incoming counter initialized to the starting sequence
// number for the client/server side of a connection. This is used in ALTS record
// to check that incoming counters are as expected, since ALTS record guarantees
// that messages are unwrapped in the same order that the peer wrapped them.
func NewInCounter(s core.Side, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
if s == core.ClientSide {
// Server counters in ALTS record have the little-endian high bit
// set.
c.value[counterLen-1] = 0x80
}
return
}
// CounterFromValue creates a new counter given an initial value.
func CounterFromValue(value []byte, overflowLen int) (c Counter) {
c.overflowLen = overflowLen
copy(c.value[:], value)
return
}
// CounterSide returns the connection side (client/server) a sequence counter is
// associated with.
func CounterSide(c []byte) core.Side {
if c[counterLen-1]&0x80 == 0x80 {
return core.ServerSide
}
return core.ClientSide
}

View File

@@ -0,0 +1,365 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package handshaker provides ALTS handshaking functionality for GCP.
package handshaker
import (
"errors"
"fmt"
"io"
"net"
"sync"
"golang.org/x/net/context"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
core "google.golang.org/grpc/credentials/alts/internal"
"google.golang.org/grpc/credentials/alts/internal/authinfo"
"google.golang.org/grpc/credentials/alts/internal/conn"
altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
)
const (
// The maximum byte size of receive frames.
frameLimit = 64 * 1024 // 64 KB
rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
// maxPendingHandshakes represents the maximum number of concurrent
// handshakes.
maxPendingHandshakes = 100
)
var (
hsProtocol = altspb.HandshakeProtocol_ALTS
appProtocols = []string{"grpc"}
recordProtocols = []string{rekeyRecordProtocolName}
keyLength = map[string]int{
rekeyRecordProtocolName: 44,
}
altsRecordFuncs = map[string]conn.ALTSRecordFunc{
// ALTS handshaker protocols.
rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
return conn.NewAES128GCMRekey(s, keyData)
},
}
// control number of concurrent created (but not closed) handshakers.
mu sync.Mutex
concurrentHandshakes = int64(0)
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
)
func init() {
for protocol, f := range altsRecordFuncs {
if err := conn.RegisterProtocol(protocol, f); err != nil {
panic(err)
}
}
}
func acquire(n int64) bool {
mu.Lock()
success := maxPendingHandshakes-concurrentHandshakes >= n
if success {
concurrentHandshakes += n
}
mu.Unlock()
return success
}
func release(n int64) {
mu.Lock()
concurrentHandshakes -= n
if concurrentHandshakes < 0 {
mu.Unlock()
panic("bad release")
}
mu.Unlock()
}
// ClientHandshakerOptions contains the client handshaker options that can
// provided by the caller.
type ClientHandshakerOptions struct {
// ClientIdentity is the handshaker client local identity.
ClientIdentity *altspb.Identity
// TargetName is the server service account name for secure name
// checking.
TargetName string
// TargetServiceAccounts contains a list of expected target service
// accounts. One of these accounts should match one of the accounts in
// the handshaker results. Otherwise, the handshake fails.
TargetServiceAccounts []string
// RPCVersions specifies the gRPC versions accepted by the client.
RPCVersions *altspb.RpcProtocolVersions
}
// ServerHandshakerOptions contains the server handshaker options that can
// provided by the caller.
type ServerHandshakerOptions struct {
// RPCVersions specifies the gRPC versions accepted by the server.
RPCVersions *altspb.RpcProtocolVersions
}
// DefaultClientHandshakerOptions returns the default client handshaker options.
func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
return &ClientHandshakerOptions{}
}
// DefaultServerHandshakerOptions returns the default client handshaker options.
func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
return &ServerHandshakerOptions{}
}
// TODO: add support for future local and remote endpoint in both client options
// and server options (server options struct does not exist now. When
// caller can provide endpoints, it should be created.
// altsHandshaker is used to complete a ALTS handshaking between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
// RPC stream used to access the ALTS Handshaker service.
stream altsgrpc.HandshakerService_DoHandshakeClient
// the connection to the peer.
conn net.Conn
// client handshake options.
clientOpts *ClientHandshakerOptions
// server handshake options.
serverOpts *ServerHandshakerOptions
// defines the side doing the handshake, client or server.
side core.Side
}
// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
conn: c,
clientOpts: opts,
side: core.ClientSide,
}, nil
}
// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
conn: c,
serverOpts: opts,
side: core.ServerSide,
}, nil
}
// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire(1) {
return nil, nil, errDropped
}
defer release(1)
if h.side != core.ClientSide {
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
}
// Create target identities from service account list.
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
for _, account := range h.clientOpts.TargetServiceAccounts {
targetIdentities = append(targetIdentities, &altspb.Identity{
IdentityOneof: &altspb.Identity_ServiceAccount{
ServiceAccount: account,
},
})
}
req := &altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_ClientStart{
ClientStart: &altspb.StartClientHandshakeReq{
HandshakeSecurityProtocol: hsProtocol,
ApplicationProtocols: appProtocols,
RecordProtocols: recordProtocols,
TargetIdentities: targetIdentities,
LocalIdentity: h.clientOpts.ClientIdentity,
TargetName: h.clientOpts.TargetName,
RpcVersions: h.clientOpts.RPCVersions,
},
},
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err
}
authInfo := authinfo.New(result)
return conn, authInfo, nil
}
// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire(1) {
return nil, nil, errDropped
}
defer release(1)
if h.side != core.ServerSide {
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
}
p := make([]byte, frameLimit)
n, err := h.conn.Read(p)
if err != nil {
return nil, nil, err
}
// Prepare server parameters.
// TODO: currently only ALTS parameters are provided. Might need to use
// more options in the future.
params := make(map[int32]*altspb.ServerHandshakeParameters)
params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
RecordProtocols: recordProtocols,
}
req := &altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_ServerStart{
ServerStart: &altspb.StartServerHandshakeReq{
ApplicationProtocols: appProtocols,
HandshakeParameters: params,
InBytes: p[:n],
RpcVersions: h.serverOpts.RPCVersions,
},
},
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err
}
authInfo := authinfo.New(result)
return conn, authInfo, nil
}
func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
resp, err := h.accessHandshakerService(req)
if err != nil {
return nil, nil, err
}
// Check of the returned status is an error.
if resp.GetStatus() != nil {
if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
}
}
var extra []byte
if req.GetServerStart() != nil {
extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
}
result, extra, err := h.processUntilDone(resp, extra)
if err != nil {
return nil, nil, err
}
// The handshaker returns a 128 bytes key. It should be truncated based
// on the returned record protocol.
keyLen, ok := keyLength[result.RecordProtocol]
if !ok {
return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol)
}
sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
if err != nil {
return nil, nil, err
}
return sc, result, nil
}
func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
if err := h.stream.Send(req); err != nil {
return nil, err
}
resp, err := h.stream.Recv()
if err != nil {
return nil, err
}
return resp, nil
}
// processUntilDone processes the handshake until the handshaker service returns
// the results. Handshaker service takes care of frame parsing, so we read
// whatever received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
for {
if len(resp.OutFrames) > 0 {
if _, err := h.conn.Write(resp.OutFrames); err != nil {
return nil, nil, err
}
}
if resp.Result != nil {
return resp.Result, extra, nil
}
buf := make([]byte, frameLimit)
n, err := h.conn.Read(buf)
if err != nil && err != io.EOF {
return nil, nil, err
}
// If there is nothing to send to the handshaker service, and
// nothing is received from the peer, then we are stuck.
// This covers the case when the peer is not responding. Note
// that handshaker service connection issues are caught in
// accessHandshakerService before we even get here.
if len(resp.OutFrames) == 0 && n == 0 {
return nil, nil, core.PeerNotRespondingError
}
// Append extra bytes from the previous interaction with the
// handshaker service with the current buffer read from conn.
p := append(extra, buf[:n]...)
resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
ReqOneof: &altspb.HandshakerReq_Next{
Next: &altspb.NextHandshakeMessageReq{
InBytes: p,
},
},
})
if err != nil {
return nil, nil, err
}
// Set extra based on handshaker service response.
if n == 0 {
extra = nil
} else {
extra = buf[resp.GetBytesConsumed():n]
}
}
}
// Close terminates the Handshaker. It should be called when the caller obtains
// the secure connection.
func (h *altsHandshaker) Close() {
h.stream.CloseSend()
}

View File

@@ -0,0 +1,261 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package handshaker
import (
"bytes"
"testing"
"time"
"golang.org/x/net/context"
grpc "google.golang.org/grpc"
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/credentials/alts/internal/testutil"
)
var (
testAppProtocols = []string{"grpc"}
testRecordProtocol = rekeyRecordProtocolName
testKey = []byte{
// 44 arbitrary bytes.
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b,
0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2,
}
testServiceAccount = "test_service_account"
testTargetServiceAccounts = []string{testServiceAccount}
testClientIdentity = &altspb.Identity{
IdentityOneof: &altspb.Identity_Hostname{
Hostname: "i_am_a_client",
},
}
)
// testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object.
type testRPCStream struct {
grpc.ClientStream
t *testing.T
isClient bool
// The resp expected to be returned by Recv(). Make sure this is set to
// the content the test requires before Recv() is invoked.
recvBuf *altspb.HandshakerResp
// false if it is the first access to Handshaker service on Envelope.
first bool
// useful for testing concurrent calls.
delay time.Duration
}
func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
resp := t.recvBuf
t.recvBuf = nil
return resp, nil
}
func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
var resp *altspb.HandshakerResp
if !t.first {
// Generate the bytes to be returned by Recv() for the initial
// handshaking.
t.first = true
if t.isClient {
resp = &altspb.HandshakerResp{
OutFrames: testutil.MakeFrame("ClientInit"),
// Simulate consuming ServerInit.
BytesConsumed: 14,
}
} else {
resp = &altspb.HandshakerResp{
OutFrames: testutil.MakeFrame("ServerInit"),
// Simulate consuming ClientInit.
BytesConsumed: 14,
}
}
} else {
// Add delay to test concurrent calls.
cleanup := stat.Update()
defer cleanup()
time.Sleep(t.delay)
// Generate the response to be returned by Recv() for the
// follow-up handshaking.
result := &altspb.HandshakerResult{
RecordProtocol: testRecordProtocol,
KeyData: testKey,
}
resp = &altspb.HandshakerResp{
Result: result,
// Simulate consuming ClientFinished or ServerFinished.
BytesConsumed: 18,
}
}
t.recvBuf = resp
return nil
}
func (t *testRPCStream) CloseSend() error {
return nil
}
var stat testutil.Stats
func TestClientHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
} {
errc := make(chan error)
stat.Reset()
for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: true,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ServerInit")
f2 := testutil.MakeFrame("ServerFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
chs := &altsHandshaker{
stream: stream,
conn: tc,
clientOpts: &ClientHandshakerOptions{
TargetServiceAccounts: testTargetServiceAccounts,
ClientIdentity: testClientIdentity,
},
side: core.ClientSide,
}
go func() {
_, context, err := chs.ClientHandshake(context.Background())
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
errc <- err
chs.Close()
}()
}
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}
// Ensure that there are no concurrent calls more than the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
}
}
}
func TestServerHandshake(t *testing.T) {
for _, testCase := range []struct {
delay time.Duration
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
} {
errc := make(chan error)
stat.Reset()
for i := 0; i < testCase.numberOfHandshakes; i++ {
stream := &testRPCStream{
t: t,
isClient: false,
}
// Preload the inbound frames.
f1 := testutil.MakeFrame("ClientInit")
f2 := testutil.MakeFrame("ClientFinished")
in := bytes.NewBuffer(f1)
in.Write(f2)
out := new(bytes.Buffer)
tc := testutil.NewTestConn(in, out)
shs := &altsHandshaker{
stream: stream,
conn: tc,
serverOpts: DefaultServerHandshakerOptions(),
side: core.ServerSide,
}
go func() {
_, context, err := shs.ServerHandshake(context.Background())
if err == nil && context == nil {
panic("expected non-nil ALTS context")
}
errc <- err
shs.Close()
}()
}
// Ensure all errors are expected.
for i := 0; i < testCase.numberOfHandshakes; i++ {
if err := <-errc; err != nil && err != errDropped {
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
}
}
// Ensure that there are no concurrent calls more than the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
}
}
}
// testUnresponsiveRPCStream is used for testing the PeerNotResponding case.
type testUnresponsiveRPCStream struct {
grpc.ClientStream
}
func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) {
return &altspb.HandshakerResp{}, nil
}
func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error {
return nil
}
func (t *testUnresponsiveRPCStream) CloseSend() error {
return nil
}
func TestPeerNotResponding(t *testing.T) {
stream := &testUnresponsiveRPCStream{}
chs := &altsHandshaker{
stream: stream,
conn: testutil.NewUnresponsiveTestConn(),
clientOpts: &ClientHandshakerOptions{
TargetServiceAccounts: testTargetServiceAccounts,
ClientIdentity: testClientIdentity,
},
side: core.ClientSide,
}
_, context, err := chs.ClientHandshake(context.Background())
chs.Close()
if context != nil {
t.Error("expected non-nil ALTS context")
}
if got, want := err, core.PeerNotRespondingError; got != want {
t.Errorf("ClientHandshake() = %v, want %v", got, want)
}
}

View File

@@ -0,0 +1,56 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package service manages connections between the VM application and the ALTS
// handshaker service.
package service
import (
"sync"
grpc "google.golang.org/grpc"
)
var (
// hsConn represents a connection to hypervisor handshaker service.
hsConn *grpc.ClientConn
mu sync.Mutex
// hsDialer will be reassigned in tests.
hsDialer = grpc.Dial
)
type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error)
// Dial dials the handshake service in the hypervisor. If a connection has
// already been established, this function returns it. Otherwise, a new
// connection is created.
func Dial(hsAddress string) (*grpc.ClientConn, error) {
mu.Lock()
defer mu.Unlock()
if hsConn == nil {
// Create a new connection to the handshaker service. Note that
// this connection stays open until the application is closed.
var err error
hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
if err != nil {
return nil, err
}
}
return hsConn, nil
}

View File

@@ -0,0 +1,69 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package service
import (
"testing"
grpc "google.golang.org/grpc"
)
const (
// The address is irrelevant in this test.
testAddress = "some_address"
)
func TestDial(t *testing.T) {
defer func() func() {
temp := hsDialer
hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
return &grpc.ClientConn{}, nil
}
return func() {
hsDialer = temp
}
}()
// Ensure that hsConn is nil at first.
hsConn = nil
// First call to Dial, it should create set hsConn.
conn1, err := Dial(testAddress)
if err != nil {
t.Fatalf("first call to Dial failed: %v", err)
}
if conn1 == nil {
t.Fatal("first call to Dial(_)=(nil, _), want not nil")
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
// Second call to Dial should return conn1 above.
conn2, err := Dial(testAddress)
if err != nil {
t.Fatalf("second call to Dial(_) failed: %v", err)
}
if got, want := conn2, conn1; got != want {
t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
}

View File

@@ -0,0 +1,151 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/gcp/altscontext.proto
package grpc_gcp // import "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type AltsContext struct {
// The application protocol negotiated for this connection.
ApplicationProtocol string `protobuf:"bytes,1,opt,name=application_protocol,json=applicationProtocol,proto3" json:"application_protocol,omitempty"`
// The record protocol negotiated for this connection.
RecordProtocol string `protobuf:"bytes,2,opt,name=record_protocol,json=recordProtocol,proto3" json:"record_protocol,omitempty"`
// The security level of the created secure channel.
SecurityLevel SecurityLevel `protobuf:"varint,3,opt,name=security_level,json=securityLevel,proto3,enum=grpc.gcp.SecurityLevel" json:"security_level,omitempty"`
// The peer service account.
PeerServiceAccount string `protobuf:"bytes,4,opt,name=peer_service_account,json=peerServiceAccount,proto3" json:"peer_service_account,omitempty"`
// The local service account.
LocalServiceAccount string `protobuf:"bytes,5,opt,name=local_service_account,json=localServiceAccount,proto3" json:"local_service_account,omitempty"`
// The RPC protocol versions supported by the peer.
PeerRpcVersions *RpcProtocolVersions `protobuf:"bytes,6,opt,name=peer_rpc_versions,json=peerRpcVersions,proto3" json:"peer_rpc_versions,omitempty"`
// Additional attributes of the peer.
PeerAttributes map[string]string `protobuf:"bytes,7,rep,name=peer_attributes,json=peerAttributes,proto3" json:"peer_attributes,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *AltsContext) Reset() { *m = AltsContext{} }
func (m *AltsContext) String() string { return proto.CompactTextString(m) }
func (*AltsContext) ProtoMessage() {}
func (*AltsContext) Descriptor() ([]byte, []int) {
return fileDescriptor_altscontext_f6b7868f9a30497f, []int{0}
}
func (m *AltsContext) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_AltsContext.Unmarshal(m, b)
}
func (m *AltsContext) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_AltsContext.Marshal(b, m, deterministic)
}
func (dst *AltsContext) XXX_Merge(src proto.Message) {
xxx_messageInfo_AltsContext.Merge(dst, src)
}
func (m *AltsContext) XXX_Size() int {
return xxx_messageInfo_AltsContext.Size(m)
}
func (m *AltsContext) XXX_DiscardUnknown() {
xxx_messageInfo_AltsContext.DiscardUnknown(m)
}
var xxx_messageInfo_AltsContext proto.InternalMessageInfo
func (m *AltsContext) GetApplicationProtocol() string {
if m != nil {
return m.ApplicationProtocol
}
return ""
}
func (m *AltsContext) GetRecordProtocol() string {
if m != nil {
return m.RecordProtocol
}
return ""
}
func (m *AltsContext) GetSecurityLevel() SecurityLevel {
if m != nil {
return m.SecurityLevel
}
return SecurityLevel_SECURITY_NONE
}
func (m *AltsContext) GetPeerServiceAccount() string {
if m != nil {
return m.PeerServiceAccount
}
return ""
}
func (m *AltsContext) GetLocalServiceAccount() string {
if m != nil {
return m.LocalServiceAccount
}
return ""
}
func (m *AltsContext) GetPeerRpcVersions() *RpcProtocolVersions {
if m != nil {
return m.PeerRpcVersions
}
return nil
}
func (m *AltsContext) GetPeerAttributes() map[string]string {
if m != nil {
return m.PeerAttributes
}
return nil
}
func init() {
proto.RegisterType((*AltsContext)(nil), "grpc.gcp.AltsContext")
proto.RegisterMapType((map[string]string)(nil), "grpc.gcp.AltsContext.PeerAttributesEntry")
}
func init() {
proto.RegisterFile("grpc/gcp/altscontext.proto", fileDescriptor_altscontext_f6b7868f9a30497f)
}
var fileDescriptor_altscontext_f6b7868f9a30497f = []byte{
// 411 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x92, 0x4d, 0x6f, 0x13, 0x31,
0x10, 0x86, 0xb5, 0x0d, 0x2d, 0xe0, 0x88, 0xb4, 0xb8, 0xa9, 0x58, 0x45, 0x42, 0x8a, 0xb8, 0xb0,
0x5c, 0x76, 0x21, 0x5c, 0x10, 0x07, 0x50, 0x8a, 0x38, 0x20, 0x71, 0x88, 0xb6, 0x12, 0x07, 0x2e,
0x2b, 0x77, 0x3a, 0xb2, 0x2c, 0x5c, 0x8f, 0x35, 0x76, 0x22, 0xf2, 0xb3, 0xf9, 0x07, 0x68, 0xed,
0xcd, 0x07, 0x1f, 0xb7, 0x9d, 0x79, 0x9f, 0x19, 0xbf, 0xb3, 0x33, 0x62, 0xa6, 0xd9, 0x43, 0xa3,
0xc1, 0x37, 0xca, 0xc6, 0x00, 0xe4, 0x22, 0xfe, 0x8c, 0xb5, 0x67, 0x8a, 0x24, 0x1f, 0xf5, 0x5a,
0xad, 0xc1, 0xcf, 0xaa, 0x3d, 0x15, 0x59, 0xb9, 0xe0, 0x89, 0x63, 0x17, 0x10, 0xd6, 0x6c, 0xe2,
0xb6, 0x03, 0xba, 0xbf, 0x27, 0x97, 0x6b, 0x5e, 0xfc, 0x1a, 0x89, 0xf1, 0xd2, 0xc6, 0xf0, 0x29,
0x77, 0x92, 0x6f, 0xc4, 0x54, 0x79, 0x6f, 0x0d, 0xa8, 0x68, 0xc8, 0x75, 0x09, 0x02, 0xb2, 0x65,
0x31, 0x2f, 0xaa, 0xc7, 0xed, 0xe5, 0x91, 0xb6, 0x1a, 0x24, 0xf9, 0x52, 0x9c, 0x33, 0x02, 0xf1,
0xdd, 0x81, 0x3e, 0x49, 0xf4, 0x24, 0xa7, 0xf7, 0xe0, 0x07, 0x31, 0xd9, 0x9b, 0xb0, 0xb8, 0x41,
0x5b, 0x8e, 0xe6, 0x45, 0x35, 0x59, 0x3c, 0xab, 0x77, 0xc6, 0xeb, 0x9b, 0x41, 0xff, 0xda, 0xcb,
0xed, 0x93, 0x70, 0x1c, 0xca, 0xd7, 0x62, 0xea, 0x11, 0xb9, 0x0b, 0xc8, 0x1b, 0x03, 0xd8, 0x29,
0x00, 0x5a, 0xbb, 0x58, 0x3e, 0x48, 0xaf, 0xc9, 0x5e, 0xbb, 0xc9, 0xd2, 0x32, 0x2b, 0x72, 0x21,
0xae, 0x2c, 0x81, 0xb2, 0xff, 0x94, 0x9c, 0xe6, 0x71, 0x92, 0xf8, 0x57, 0xcd, 0x17, 0xf1, 0x34,
0xbd, 0xc2, 0x1e, 0xba, 0x0d, 0x72, 0x30, 0xe4, 0x42, 0x79, 0x36, 0x2f, 0xaa, 0xf1, 0xe2, 0xf9,
0xc1, 0x68, 0xeb, 0x61, 0x37, 0xd7, 0xb7, 0x01, 0x6a, 0xcf, 0xfb, 0xba, 0xd6, 0xc3, 0x2e, 0x21,
0x5b, 0x91, 0x52, 0x9d, 0x8a, 0x91, 0xcd, 0xed, 0x3a, 0x62, 0x28, 0x1f, 0xce, 0x47, 0xd5, 0x78,
0xf1, 0xea, 0xd0, 0xe8, 0xe8, 0xe7, 0xd7, 0x2b, 0x44, 0x5e, 0xee, 0xd9, 0xcf, 0x2e, 0xf2, 0xb6,
0x9d, 0xf8, 0x3f, 0x92, 0xb3, 0xa5, 0xb8, 0xfc, 0x0f, 0x26, 0x2f, 0xc4, 0xe8, 0x07, 0x6e, 0x87,
0x35, 0xf5, 0x9f, 0x72, 0x2a, 0x4e, 0x37, 0xca, 0xae, 0x71, 0x58, 0x46, 0x0e, 0xde, 0x9f, 0xbc,
0x2b, 0xae, 0xad, 0xb8, 0x32, 0x94, 0x1d, 0xf4, 0x47, 0x54, 0x1b, 0x17, 0x91, 0x9d, 0xb2, 0xd7,
0x17, 0x47, 0x66, 0xd2, 0x74, 0xab, 0xe2, 0xfb, 0x47, 0x4d, 0xa4, 0x2d, 0xd6, 0x9a, 0xac, 0x72,
0xba, 0x26, 0xd6, 0x4d, 0x3a, 0x2e, 0x60, 0xbc, 0x43, 0x17, 0x8d, 0xb2, 0x21, 0x9d, 0x62, 0xb3,
0xeb, 0xd2, 0xa4, 0x2b, 0x48, 0x50, 0xa7, 0xc1, 0xdf, 0x9e, 0xa5, 0xf8, 0xed, 0xef, 0x00, 0x00,
0x00, 0xff, 0xff, 0x9b, 0x8c, 0xe4, 0x6a, 0xba, 0x02, 0x00, 0x00,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,178 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: grpc/gcp/transport_security_common.proto
package grpc_gcp // import "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
// The security level of the created channel. The list is sorted in increasing
// level of security. This order must always be maintained.
type SecurityLevel int32
const (
SecurityLevel_SECURITY_NONE SecurityLevel = 0
SecurityLevel_INTEGRITY_ONLY SecurityLevel = 1
SecurityLevel_INTEGRITY_AND_PRIVACY SecurityLevel = 2
)
var SecurityLevel_name = map[int32]string{
0: "SECURITY_NONE",
1: "INTEGRITY_ONLY",
2: "INTEGRITY_AND_PRIVACY",
}
var SecurityLevel_value = map[string]int32{
"SECURITY_NONE": 0,
"INTEGRITY_ONLY": 1,
"INTEGRITY_AND_PRIVACY": 2,
}
func (x SecurityLevel) String() string {
return proto.EnumName(SecurityLevel_name, int32(x))
}
func (SecurityLevel) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_transport_security_common_71945991f2c3b4a6, []int{0}
}
// Max and min supported RPC protocol versions.
type RpcProtocolVersions struct {
// Maximum supported RPC version.
MaxRpcVersion *RpcProtocolVersions_Version `protobuf:"bytes,1,opt,name=max_rpc_version,json=maxRpcVersion,proto3" json:"max_rpc_version,omitempty"`
// Minimum supported RPC version.
MinRpcVersion *RpcProtocolVersions_Version `protobuf:"bytes,2,opt,name=min_rpc_version,json=minRpcVersion,proto3" json:"min_rpc_version,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RpcProtocolVersions) Reset() { *m = RpcProtocolVersions{} }
func (m *RpcProtocolVersions) String() string { return proto.CompactTextString(m) }
func (*RpcProtocolVersions) ProtoMessage() {}
func (*RpcProtocolVersions) Descriptor() ([]byte, []int) {
return fileDescriptor_transport_security_common_71945991f2c3b4a6, []int{0}
}
func (m *RpcProtocolVersions) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RpcProtocolVersions.Unmarshal(m, b)
}
func (m *RpcProtocolVersions) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RpcProtocolVersions.Marshal(b, m, deterministic)
}
func (dst *RpcProtocolVersions) XXX_Merge(src proto.Message) {
xxx_messageInfo_RpcProtocolVersions.Merge(dst, src)
}
func (m *RpcProtocolVersions) XXX_Size() int {
return xxx_messageInfo_RpcProtocolVersions.Size(m)
}
func (m *RpcProtocolVersions) XXX_DiscardUnknown() {
xxx_messageInfo_RpcProtocolVersions.DiscardUnknown(m)
}
var xxx_messageInfo_RpcProtocolVersions proto.InternalMessageInfo
func (m *RpcProtocolVersions) GetMaxRpcVersion() *RpcProtocolVersions_Version {
if m != nil {
return m.MaxRpcVersion
}
return nil
}
func (m *RpcProtocolVersions) GetMinRpcVersion() *RpcProtocolVersions_Version {
if m != nil {
return m.MinRpcVersion
}
return nil
}
// RPC version contains a major version and a minor version.
type RpcProtocolVersions_Version struct {
Major uint32 `protobuf:"varint,1,opt,name=major,proto3" json:"major,omitempty"`
Minor uint32 `protobuf:"varint,2,opt,name=minor,proto3" json:"minor,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *RpcProtocolVersions_Version) Reset() { *m = RpcProtocolVersions_Version{} }
func (m *RpcProtocolVersions_Version) String() string { return proto.CompactTextString(m) }
func (*RpcProtocolVersions_Version) ProtoMessage() {}
func (*RpcProtocolVersions_Version) Descriptor() ([]byte, []int) {
return fileDescriptor_transport_security_common_71945991f2c3b4a6, []int{0, 0}
}
func (m *RpcProtocolVersions_Version) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_RpcProtocolVersions_Version.Unmarshal(m, b)
}
func (m *RpcProtocolVersions_Version) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_RpcProtocolVersions_Version.Marshal(b, m, deterministic)
}
func (dst *RpcProtocolVersions_Version) XXX_Merge(src proto.Message) {
xxx_messageInfo_RpcProtocolVersions_Version.Merge(dst, src)
}
func (m *RpcProtocolVersions_Version) XXX_Size() int {
return xxx_messageInfo_RpcProtocolVersions_Version.Size(m)
}
func (m *RpcProtocolVersions_Version) XXX_DiscardUnknown() {
xxx_messageInfo_RpcProtocolVersions_Version.DiscardUnknown(m)
}
var xxx_messageInfo_RpcProtocolVersions_Version proto.InternalMessageInfo
func (m *RpcProtocolVersions_Version) GetMajor() uint32 {
if m != nil {
return m.Major
}
return 0
}
func (m *RpcProtocolVersions_Version) GetMinor() uint32 {
if m != nil {
return m.Minor
}
return 0
}
func init() {
proto.RegisterType((*RpcProtocolVersions)(nil), "grpc.gcp.RpcProtocolVersions")
proto.RegisterType((*RpcProtocolVersions_Version)(nil), "grpc.gcp.RpcProtocolVersions.Version")
proto.RegisterEnum("grpc.gcp.SecurityLevel", SecurityLevel_name, SecurityLevel_value)
}
func init() {
proto.RegisterFile("grpc/gcp/transport_security_common.proto", fileDescriptor_transport_security_common_71945991f2c3b4a6)
}
var fileDescriptor_transport_security_common_71945991f2c3b4a6 = []byte{
// 323 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x91, 0x41, 0x4b, 0x3b, 0x31,
0x10, 0xc5, 0xff, 0x5b, 0xf8, 0xab, 0x44, 0x56, 0xeb, 0x6a, 0x41, 0xc5, 0x83, 0x08, 0x42, 0xf1,
0x90, 0x05, 0xc5, 0xb3, 0xb4, 0xb5, 0x48, 0xa1, 0x6e, 0xeb, 0xb6, 0x16, 0xea, 0x25, 0xc4, 0x18,
0x42, 0x24, 0x9b, 0x09, 0xb3, 0xb1, 0xd4, 0xaf, 0xec, 0xa7, 0x90, 0x4d, 0xbb, 0x14, 0xc1, 0x8b,
0xb7, 0xbc, 0xc7, 0xcc, 0x6f, 0x32, 0xf3, 0x48, 0x5b, 0xa1, 0x13, 0xa9, 0x12, 0x2e, 0xf5, 0xc8,
0x6d, 0xe9, 0x00, 0x3d, 0x2b, 0xa5, 0xf8, 0x40, 0xed, 0x3f, 0x99, 0x80, 0xa2, 0x00, 0x4b, 0x1d,
0x82, 0x87, 0x64, 0xa7, 0xaa, 0xa4, 0x4a, 0xb8, 0x8b, 0xaf, 0x88, 0x1c, 0xe6, 0x4e, 0x8c, 0x2b,
0x5b, 0x80, 0x99, 0x49, 0x2c, 0x35, 0xd8, 0x32, 0x79, 0x24, 0xfb, 0x05, 0x5f, 0x32, 0x74, 0x82,
0x2d, 0x56, 0xde, 0x71, 0x74, 0x1e, 0xb5, 0x77, 0xaf, 0x2f, 0x69, 0xdd, 0x4b, 0x7f, 0xe9, 0xa3,
0xeb, 0x47, 0x1e, 0x17, 0x7c, 0x99, 0x3b, 0xb1, 0x96, 0x01, 0xa7, 0xed, 0x0f, 0x5c, 0xe3, 0x6f,
0x38, 0x6d, 0x37, 0xb8, 0xd3, 0x5b, 0xb2, 0x5d, 0x93, 0x8f, 0xc8, 0xff, 0x82, 0xbf, 0x03, 0x86,
0xef, 0xc5, 0xf9, 0x4a, 0x04, 0x57, 0x5b, 0xc0, 0x30, 0xa5, 0x72, 0x2b, 0x71, 0xf5, 0x44, 0xe2,
0xc9, 0xfa, 0x1e, 0x43, 0xb9, 0x90, 0x26, 0x39, 0x20, 0xf1, 0xa4, 0xdf, 0x7b, 0xce, 0x07, 0xd3,
0x39, 0xcb, 0x46, 0x59, 0xbf, 0xf9, 0x2f, 0x49, 0xc8, 0xde, 0x20, 0x9b, 0xf6, 0x1f, 0x82, 0x37,
0xca, 0x86, 0xf3, 0x66, 0x94, 0x9c, 0x90, 0xd6, 0xc6, 0xeb, 0x64, 0xf7, 0x6c, 0x9c, 0x0f, 0x66,
0x9d, 0xde, 0xbc, 0xd9, 0xe8, 0x2e, 0x49, 0x4b, 0xc3, 0x6a, 0x07, 0x6e, 0x7c, 0x49, 0xb5, 0xf5,
0x12, 0x2d, 0x37, 0xdd, 0xb3, 0x69, 0x9d, 0x41, 0x3d, 0xb2, 0x17, 0x12, 0x08, 0x2b, 0x8e, 0xa3,
0x97, 0x3b, 0x05, 0xa0, 0x8c, 0xa4, 0x0a, 0x0c, 0xb7, 0x8a, 0x02, 0xaa, 0x34, 0xc4, 0x27, 0x50,
0xbe, 0x49, 0xeb, 0x35, 0x37, 0x65, 0x5a, 0x11, 0xd3, 0x9a, 0x98, 0x86, 0xe8, 0x42, 0x11, 0x53,
0xc2, 0xbd, 0x6e, 0x05, 0x7d, 0xf3, 0x1d, 0x00, 0x00, 0xff, 0xff, 0x31, 0x14, 0xb4, 0x11, 0xf6,
0x01, 0x00, 0x00,
}

View File

@@ -0,0 +1,35 @@
#!/bin/bash
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux -o pipefail
TMP=$(mktemp -d)
function finish {
rm -rf "$TMP"
}
trap finish EXIT
pushd "$TMP"
mkdir -p grpc/gcp
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/gcp/altscontext.proto > grpc/gcp/altscontext.proto
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/gcp/handshaker.proto > grpc/gcp/handshaker.proto
curl https://raw.githubusercontent.com/grpc/grpc-proto/master/grpc/gcp/transport_security_common.proto > grpc/gcp/transport_security_common.proto
protoc --go_out=plugins=grpc,paths=source_relative:. -I. grpc/gcp/*.proto
popd
rm -f proto/grpc_gcp/*.pb.go
cp "$TMP"/grpc/gcp/*.pb.go proto/grpc_gcp/

View File

@@ -0,0 +1,125 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package testutil include useful test utilities for the handshaker.
package testutil
import (
"bytes"
"encoding/binary"
"io"
"net"
"sync"
"google.golang.org/grpc/credentials/alts/internal/conn"
)
// Stats is used to collect statistics about concurrent handshake calls.
type Stats struct {
mu sync.Mutex
calls int
MaxConcurrentCalls int
}
// Update updates the statistics by adding one call.
func (s *Stats) Update() func() {
s.mu.Lock()
s.calls++
if s.calls > s.MaxConcurrentCalls {
s.MaxConcurrentCalls = s.calls
}
s.mu.Unlock()
return func() {
s.mu.Lock()
s.calls--
s.mu.Unlock()
}
}
// Reset resets the statistics.
func (s *Stats) Reset() {
s.mu.Lock()
defer s.mu.Unlock()
s.calls = 0
s.MaxConcurrentCalls = 0
}
// testConn mimics a net.Conn to the peer.
type testConn struct {
net.Conn
in *bytes.Buffer
out *bytes.Buffer
}
// NewTestConn creates a new instance of testConn object.
func NewTestConn(in *bytes.Buffer, out *bytes.Buffer) net.Conn {
return &testConn{
in: in,
out: out,
}
}
// Read reads from the in buffer.
func (c *testConn) Read(b []byte) (n int, err error) {
return c.in.Read(b)
}
// Write writes to the out buffer.
func (c *testConn) Write(b []byte) (n int, err error) {
return c.out.Write(b)
}
// Close closes the testConn object.
func (c *testConn) Close() error {
return nil
}
// unresponsiveTestConn mimics a net.Conn for an unresponsive peer. It is used
// for testing the PeerNotResponding case.
type unresponsiveTestConn struct {
net.Conn
}
// NewUnresponsiveTestConn creates a new instance of unresponsiveTestConn object.
func NewUnresponsiveTestConn() net.Conn {
return &unresponsiveTestConn{}
}
// Read reads from the in buffer.
func (c *unresponsiveTestConn) Read(b []byte) (n int, err error) {
return 0, io.EOF
}
// Write writes to the out buffer.
func (c *unresponsiveTestConn) Write(b []byte) (n int, err error) {
return 0, nil
}
// Close closes the TestConn object.
func (c *unresponsiveTestConn) Close() error {
return nil
}
// MakeFrame creates a handshake frame.
func MakeFrame(pl string) []byte {
f := make([]byte, len(pl)+conn.MsgLenFieldSize)
binary.LittleEndian.PutUint32(f, uint32(len(pl)))
copy(f[conn.MsgLenFieldSize:], []byte(pl))
return f
}

141
vendor/google.golang.org/grpc/credentials/alts/utils.go generated vendored Normal file
View File

@@ -0,0 +1,141 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"golang.org/x/net/context"
"google.golang.org/grpc/peer"
)
const (
linuxProductNameFile = "/sys/class/dmi/id/product_name"
windowsCheckCommand = "powershell.exe"
windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS"
powershellOutputFilter = "Manufacturer"
windowsManufacturerRegex = ":(.*)"
)
type platformError string
func (k platformError) Error() string {
return fmt.Sprintf("%s is not supported", string(k))
}
var (
// The following two variables will be reassigned in tests.
runningOS = runtime.GOOS
manufacturerReader = func() (io.Reader, error) {
switch runningOS {
case "linux":
return os.Open(linuxProductNameFile)
case "windows":
cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs)
out, err := cmd.Output()
if err != nil {
return nil, err
}
for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") {
if strings.HasPrefix(line, powershellOutputFilter) {
re := regexp.MustCompile(windowsManufacturerRegex)
name := re.FindString(line)
name = strings.TrimLeft(name, ":")
return strings.NewReader(name), nil
}
}
return nil, errors.New("cannot determine the machine's manufacturer")
default:
return nil, platformError(runningOS)
}
}
vmOnGCP bool
)
// isRunningOnGCP checks whether the local system, without doing a network request is
// running on GCP.
func isRunningOnGCP() bool {
manufacturer, err := readManufacturer()
if err != nil {
log.Fatalf("failure to read manufacturer information: %v", err)
}
name := string(manufacturer)
switch runningOS {
case "linux":
name = strings.TrimSpace(name)
return name == "Google" || name == "Google Compute Engine"
case "windows":
name = strings.Replace(name, " ", "", -1)
name = strings.Replace(name, "\n", "", -1)
name = strings.Replace(name, "\r", "", -1)
return name == "Google"
default:
log.Fatal(platformError(runningOS))
}
return false
}
func readManufacturer() ([]byte, error) {
reader, err := manufacturerReader()
if err != nil {
return nil, err
}
if reader == nil {
return nil, errors.New("got nil reader")
}
manufacturer, err := ioutil.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err)
}
return manufacturer, nil
}
// AuthInfoFromContext extracts the alts.AuthInfo object from the given context,
// if it exists. This API should be used by gRPC server RPC handlers to get
// information about the communicating peer. For client-side, use grpc.Peer()
// CallOption.
func AuthInfoFromContext(ctx context.Context) (AuthInfo, error) {
p, ok := peer.FromContext(ctx)
if !ok {
return nil, errors.New("no Peer found in Context")
}
return AuthInfoFromPeer(p)
}
// AuthInfoFromPeer extracts the alts.AuthInfo object from the given peer, if it
// exists. This API should be used by gRPC clients after obtaining a peer object
// using the grpc.Peer() CallOption.
func AuthInfoFromPeer(p *peer.Peer) (AuthInfo, error) {
altsAuthInfo, ok := p.AuthInfo.(AuthInfo)
if !ok {
return nil, errors.New("no alts.AuthInfo found in Peer")
}
return altsAuthInfo, nil
}

View File

@@ -0,0 +1,139 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package alts
import (
"io"
"strings"
"testing"
"golang.org/x/net/context"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/peer"
)
func TestIsRunningOnGCP(t *testing.T) {
for _, tc := range []struct {
description string
testOS string
testReader io.Reader
out bool
}{
// Linux tests.
{"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false},
{"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true},
{"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true},
{"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true},
// Windows tests.
{"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false},
{"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true},
{"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true},
} {
reverseFunc := setup(tc.testOS, tc.testReader)
if got, want := isRunningOnGCP(), tc.out; got != want {
t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want)
}
reverseFunc()
}
}
func setup(testOS string, testReader io.Reader) func() {
tmpOS := runningOS
tmpReader := manufacturerReader
// Set test OS and reader function.
runningOS = testOS
manufacturerReader = func() (io.Reader, error) {
return testReader, nil
}
return func() {
runningOS = tmpOS
manufacturerReader = tmpReader
}
}
func TestAuthInfoFromContext(t *testing.T) {
ctx := context.Background()
altsAuthInfo := &fakeALTSAuthInfo{}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
}
for _, tc := range []struct {
desc string
ctx context.Context
success bool
out AuthInfo
}{
{
"working case",
peer.NewContext(ctx, p),
true,
altsAuthInfo,
},
} {
authInfo, err := AuthInfoFromContext(tc.ctx)
if got, want := (err == nil), tc.success; got != want {
t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
}
if got, want := authInfo, tc.out; got != want {
t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
}
}
}
func TestAuthInfoFromPeer(t *testing.T) {
altsAuthInfo := &fakeALTSAuthInfo{}
p := &peer.Peer{
AuthInfo: altsAuthInfo,
}
for _, tc := range []struct {
desc string
p *peer.Peer
success bool
out AuthInfo
}{
{
"working case",
p,
true,
altsAuthInfo,
},
} {
authInfo, err := AuthInfoFromPeer(tc.p)
if got, want := (err == nil), tc.success; got != want {
t.Errorf("%v: AuthInfoFromPeer(_)=(err=nil)=%v, want %v", tc.desc, got, want)
}
if got, want := authInfo, tc.out; got != want {
t.Errorf("%v:, AuthInfoFromPeer(_)=(%v, _), want (%v, _)", tc.desc, got, want)
}
}
}
type fakeALTSAuthInfo struct{}
func (*fakeALTSAuthInfo) AuthType() string { return "" }
func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
func (*fakeALTSAuthInfo) RecordProtocol() string { return "" }
func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
return altspb.SecurityLevel_SECURITY_NONE
}
func (*fakeALTSAuthInfo) PeerServiceAccount() string { return "" }
func (*fakeALTSAuthInfo) LocalServiceAccount() string { return "" }
func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }

View File

@@ -31,6 +31,7 @@ import (
"net"
"strings"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
)
@@ -107,6 +108,25 @@ type TransportCredentials interface {
OverrideServerName(string) error
}
// Bundle is a combination of TransportCredentials and PerRPCCredentials.
//
// It also contains a mode switching method, so it can be used as a combination
// of different credential policies.
//
// Bundle cannot be used together with individual TransportCredentials.
// PerRPCCredentials from Bundle will be appended to other PerRPCCredentials.
//
// This API is experimental.
type Bundle interface {
TransportCredentials() TransportCredentials
PerRPCCredentials() PerRPCCredentials
// NewWithMode should make a copy of Bundle, and switch mode. Modifying the
// existing Bundle may cause races.
//
// NewWithMode returns nil if the requested mode is not supported.
NewWithMode(mode string) (Bundle, error)
}
// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
@@ -118,6 +138,18 @@ func (t TLSInfo) AuthType() string {
return "tls"
}
// GetChannelzSecurityValue returns security info requested by channelz.
func (t TLSInfo) GetChannelzSecurityValue() ChannelzSecurityValue {
v := &TLSChannelzSecurityValue{
StandardName: cipherSuiteLookup[t.State.CipherSuite],
}
// Currently there's no way to get LocalCertificate info from tls package.
if len(t.State.PeerCertificates) > 0 {
v.RemoteCertificate = t.State.PeerCertificates[0].Raw
}
return v
}
// tlsCreds is the credentials required for authenticating a connection using TLS.
type tlsCreds struct {
// TLS configuration
@@ -155,7 +187,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
case <-ctx.Done():
return nil, nil, ctx.Err()
}
return conn, TLSInfo{conn.ConnectionState()}, nil
return tlsConn{Conn: conn, rawConn: rawConn}, TLSInfo{conn.ConnectionState()}, nil
}
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
@@ -163,7 +195,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
if err := conn.Handshake(); err != nil {
return nil, nil, err
}
return conn, TLSInfo{conn.ConnectionState()}, nil
return tlsConn{Conn: conn, rawConn: rawConn}, TLSInfo{conn.ConnectionState()}, nil
}
func (c *tlsCreds) Clone() TransportCredentials {
@@ -218,3 +250,63 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error
}
return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
}
// ChannelzSecurityInfo defines the interface that security protocols should implement
// in order to provide security info to channelz.
type ChannelzSecurityInfo interface {
GetSecurityValue() ChannelzSecurityValue
}
// ChannelzSecurityValue defines the interface that GetSecurityValue() return value
// should satisfy. This interface should only be satisfied by *TLSChannelzSecurityValue
// and *OtherChannelzSecurityValue.
type ChannelzSecurityValue interface {
isChannelzSecurityValue()
}
// TLSChannelzSecurityValue defines the struct that TLS protocol should return
// from GetSecurityValue(), containing security info like cipher and certificate used.
type TLSChannelzSecurityValue struct {
StandardName string
LocalCertificate []byte
RemoteCertificate []byte
}
func (*TLSChannelzSecurityValue) isChannelzSecurityValue() {}
// OtherChannelzSecurityValue defines the struct that non-TLS protocol should return
// from GetSecurityValue(), which contains protocol specific security info. Note
// the Value field will be sent to users of channelz requesting channel info, and
// thus sensitive info should better be avoided.
type OtherChannelzSecurityValue struct {
Name string
Value proto.Message
}
func (*OtherChannelzSecurityValue) isChannelzSecurityValue() {}
type tlsConn struct {
*tls.Conn
rawConn net.Conn
}
var cipherSuiteLookup = map[uint16]string{
tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA",
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA",
tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA",
tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "TLS_RSA_WITH_AES_128_GCM_SHA256",
tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384",
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
tls.TLS_FALLBACK_SCSV: "TLS_FALLBACK_SCSV",
}

View File

@@ -1,5 +1,4 @@
// +build go1.7
// +build !go1.8
// +build go1.7,!go1.8
/*
*

View File

@@ -24,6 +24,14 @@ import (
"crypto/tls"
)
func init() {
cipherSuiteLookup[tls.TLS_RSA_WITH_AES_128_CBC_SHA256] = "TLS_RSA_WITH_AES_128_CBC_SHA256"
cipherSuiteLookup[tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256] = "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256"
cipherSuiteLookup[tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256] = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"
cipherSuiteLookup[tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305] = "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305"
cipherSuiteLookup[tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305] = "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305"
}
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.

35
vendor/google.golang.org/grpc/credentials/go19.go generated vendored Normal file
View File

@@ -0,0 +1,35 @@
// +build go1.9,!appengine
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package credentials
import (
"errors"
"syscall"
)
// implements the syscall.Conn interface
func (c tlsConn) SyscallConn() (syscall.RawConn, error) {
conn, ok := c.rawConn.(syscall.Conn)
if !ok {
return nil, errors.New("RawConn does not implement syscall.Conn")
}
return conn.SyscallConn()
}

View File

@@ -0,0 +1,101 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package google defines credentials for google cloud services.
package google
import (
"fmt"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/alts"
"google.golang.org/grpc/credentials/oauth"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
)
const tokenRequestTimeout = 30 * time.Second
// NewDefaultCredentials returns a credentials bundle that is configured to work
// with google services.
//
// This API is experimental.
func NewDefaultCredentials() credentials.Bundle {
c := &creds{}
bundle, err := c.NewWithMode(internal.CredsBundleModeFallback)
if err != nil {
grpclog.Warningf("google default creds: failed to create new creds: %v", err)
}
return bundle
}
// creds implements credentials.Bundle.
type creds struct {
// Supported modes are defined in internal/internal.go.
mode string
// The transport credentials associated with this bundle.
transportCreds credentials.TransportCredentials
// The per RPC credentials associated with this bundle.
perRPCCreds credentials.PerRPCCredentials
}
func (c *creds) TransportCredentials() credentials.TransportCredentials {
return c.transportCreds
}
func (c *creds) PerRPCCredentials() credentials.PerRPCCredentials {
if c == nil {
return nil
}
return c.perRPCCreds
}
// NewWithMode should make a copy of Bundle, and switch mode. Modifying the
// existing Bundle may cause races.
func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) {
newCreds := &creds{mode: mode}
// Create transport credentials.
switch mode {
case internal.CredsBundleModeFallback:
newCreds.transportCreds = credentials.NewTLS(nil)
case internal.CredsBundleModeBackendFromBalancer, internal.CredsBundleModeBalancer:
// Only the clients can use google default credentials, so we only need
// to create new ALTS client creds here.
newCreds.transportCreds = alts.NewClientCreds(alts.DefaultClientOptions())
default:
return nil, fmt.Errorf("google default creds: unsupported mode: %v", mode)
}
if mode == internal.CredsBundleModeFallback || mode == internal.CredsBundleModeBackendFromBalancer {
// Create per RPC credentials.
// For the time being, we required per RPC credentials for both TLS and
// ALTS. In the future, this will only be required for TLS.
ctx, cancel := context.WithTimeout(context.Background(), tokenRequestTimeout)
defer cancel()
var err error
newCreds.perRPCCreds, err = oauth.NewApplicationDefault(ctx)
if err != nil {
grpclog.Warningf("google default creds: failed to create application oauth: %v", err)
}
}
return newCreds, nil
}

465
vendor/google.golang.org/grpc/dialoptions.go generated vendored Normal file
View File

@@ -0,0 +1,465 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"fmt"
"net"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats"
)
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
cp Compressor
dc Decompressor
bs backoff.Strategy
block bool
insecure bool
timeout time.Duration
scChan <-chan ServiceConfig
authority string
copts transport.ConnectOptions
callOptions []CallOption
// This is used by v1 balancer dial option WithBalancer to support v1
// balancer, and also by WithBalancerName dial option.
balancerBuilder balancer.Builder
// This is to support grpclb.
resolverBuilder resolver.Builder
waitForHandshake bool
channelzParentID int64
disableServiceConfig bool
disableRetry bool
}
// DialOption configures how we set up the connection.
type DialOption interface {
apply(*dialOptions)
}
// EmptyDialOption does not alter the dial configuration. It can be embedded in
// another structure to build custom dial options.
//
// This API is EXPERIMENTAL.
type EmptyDialOption struct{}
func (EmptyDialOption) apply(*dialOptions) {}
// funcDialOption wraps a function that modifies dialOptions into an
// implementation of the DialOption interface.
type funcDialOption struct {
f func(*dialOptions)
}
func (fdo *funcDialOption) apply(do *dialOptions) {
fdo.f(do)
}
func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
return &funcDialOption{
f: f,
}
}
// WithWaitForHandshake blocks until the initial settings frame is received from
// the server before assigning RPCs to the connection. Experimental API.
func WithWaitForHandshake() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.waitForHandshake = true
})
}
// WithWriteBufferSize determines how much data can be batched before doing a
// write on the wire. The corresponding memory allocation for this buffer will
// be twice the size to keep syscalls low. The default value for this buffer is
// 32KB.
//
// Zero will disable the write buffer such that each write will be on underlying
// connection. Note: A Send call may not directly translate to a write.
func WithWriteBufferSize(s int) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.WriteBufferSize = s
})
}
// WithReadBufferSize lets you set the size of read buffer, this determines how
// much data can be read at most for each read syscall.
//
// The default value for this buffer is 32KB. Zero will disable read buffer for
// a connection so data framer can access the underlying conn directly.
func WithReadBufferSize(s int) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.ReadBufferSize = s
})
}
// WithInitialWindowSize returns a DialOption which sets the value for initial
// window size on a stream. The lower bound for window size is 64K and any value
// smaller than that will be ignored.
func WithInitialWindowSize(s int32) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.InitialWindowSize = s
})
}
// WithInitialConnWindowSize returns a DialOption which sets the value for
// initial window size on a connection. The lower bound for window size is 64K
// and any value smaller than that will be ignored.
func WithInitialConnWindowSize(s int32) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.InitialConnWindowSize = s
})
}
// WithMaxMsgSize returns a DialOption which sets the maximum message size the
// client can receive.
//
// Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead.
func WithMaxMsgSize(s int) DialOption {
return WithDefaultCallOptions(MaxCallRecvMsgSize(s))
}
// WithDefaultCallOptions returns a DialOption which sets the default
// CallOptions for calls over the connection.
func WithDefaultCallOptions(cos ...CallOption) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.callOptions = append(o.callOptions, cos...)
})
}
// WithCodec returns a DialOption which sets a codec for message marshaling and
// unmarshaling.
//
// Deprecated: use WithDefaultCallOptions(CallCustomCodec(c)) instead.
func WithCodec(c Codec) DialOption {
return WithDefaultCallOptions(CallCustomCodec(c))
}
// WithCompressor returns a DialOption which sets a Compressor to use for
// message compression. It has lower priority than the compressor set by the
// UseCompressor CallOption.
//
// Deprecated: use UseCompressor instead.
func WithCompressor(cp Compressor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.cp = cp
})
}
// WithDecompressor returns a DialOption which sets a Decompressor to use for
// incoming message decompression. If incoming response messages are encoded
// using the decompressor's Type(), it will be used. Otherwise, the message
// encoding will be used to look up the compressor registered via
// encoding.RegisterCompressor, which will then be used to decompress the
// message. If no compressor is registered for the encoding, an Unimplemented
// status error will be returned.
//
// Deprecated: use encoding.RegisterCompressor instead.
func WithDecompressor(dc Decompressor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.dc = dc
})
}
// WithBalancer returns a DialOption which sets a load balancer with the v1 API.
// Name resolver will be ignored if this DialOption is specified.
//
// Deprecated: use the new balancer APIs in balancer package and
// WithBalancerName.
func WithBalancer(b Balancer) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.balancerBuilder = &balancerWrapperBuilder{
b: b,
}
})
}
// WithBalancerName sets the balancer that the ClientConn will be initialized
// with. Balancer registered with balancerName will be used. This function
// panics if no balancer was registered by balancerName.
//
// The balancer cannot be overridden by balancer option specified by service
// config.
//
// This is an EXPERIMENTAL API.
func WithBalancerName(balancerName string) DialOption {
builder := balancer.Get(balancerName)
if builder == nil {
panic(fmt.Sprintf("grpc.WithBalancerName: no balancer is registered for name %v", balancerName))
}
return newFuncDialOption(func(o *dialOptions) {
o.balancerBuilder = builder
})
}
// withResolverBuilder is only for grpclb.
func withResolverBuilder(b resolver.Builder) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.resolverBuilder = b
})
}
// WithServiceConfig returns a DialOption which has a channel to read the
// service configuration.
//
// Deprecated: service config should be received through name resolver, as
// specified here.
// https://github.com/grpc/grpc/blob/master/doc/service_config.md
func WithServiceConfig(c <-chan ServiceConfig) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.scChan = c
})
}
// WithBackoffMaxDelay configures the dialer to use the provided maximum delay
// when backing off after failed connection attempts.
func WithBackoffMaxDelay(md time.Duration) DialOption {
return WithBackoffConfig(BackoffConfig{MaxDelay: md})
}
// WithBackoffConfig configures the dialer to use the provided backoff
// parameters after connection failures.
//
// Use WithBackoffMaxDelay until more parameters on BackoffConfig are opened up
// for use.
func WithBackoffConfig(b BackoffConfig) DialOption {
return withBackoff(backoff.Exponential{
MaxDelay: b.MaxDelay,
})
}
// withBackoff sets the backoff strategy used for connectRetryNum after a failed
// connection attempt.
//
// This can be exported if arbitrary backoff strategies are allowed by gRPC.
func withBackoff(bs backoff.Strategy) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.bs = bs
})
}
// WithBlock returns a DialOption which makes caller of Dial blocks until the
// underlying connection is up. Without this, Dial returns immediately and
// connecting the server happens in background.
func WithBlock() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.block = true
})
}
// WithInsecure returns a DialOption which disables transport security for this
// ClientConn. Note that transport security is required unless WithInsecure is
// set.
func WithInsecure() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.insecure = true
})
}
// WithTransportCredentials returns a DialOption which configures a connection
// level security credentials (e.g., TLS/SSL). This should not be used together
// with WithCredentialsBundle.
func WithTransportCredentials(creds credentials.TransportCredentials) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.TransportCredentials = creds
})
}
// WithPerRPCCredentials returns a DialOption which sets credentials and places
// auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
})
}
// WithCredentialsBundle returns a DialOption to set a credentials bundle for
// the ClientConn.WithCreds. This should not be used together with
// WithTransportCredentials.
//
// This API is experimental.
func WithCredentialsBundle(b credentials.Bundle) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.CredsBundle = b
})
}
// WithTimeout returns a DialOption that configures a timeout for dialing a
// ClientConn initially. This is valid if and only if WithBlock() is present.
//
// Deprecated: use DialContext and context.WithTimeout instead.
func WithTimeout(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.timeout = d
})
}
func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.Dialer = f
})
}
func init() {
internal.WithContextDialer = withContextDialer
internal.WithResolverBuilder = withResolverBuilder
}
// WithDialer returns a DialOption that specifies a function to use for dialing
// network addresses. If FailOnNonTempDialError() is set to true, and an error
// is returned by f, gRPC checks the error's Temporary() method to decide if it
// should try to reconnect to the network address.
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return withContextDialer(
func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
return f(addr, deadline.Sub(time.Now()))
}
return f(addr, 0)
})
}
// WithStatsHandler returns a DialOption that specifies the stats handler for
// all the RPCs and underlying network connections in this ClientConn.
func WithStatsHandler(h stats.Handler) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.StatsHandler = h
})
}
// FailOnNonTempDialError returns a DialOption that specifies if gRPC fails on
// non-temporary dial errors. If f is true, and dialer returns a non-temporary
// error, gRPC will fail the connection to the network address and won't try to
// reconnect. The default value of FailOnNonTempDialError is false.
//
// FailOnNonTempDialError only affects the initial dial, and does not do
// anything useful unless you are also using WithBlock().
//
// This is an EXPERIMENTAL API.
func FailOnNonTempDialError(f bool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.FailOnNonTempDialError = f
})
}
// WithUserAgent returns a DialOption that specifies a user agent string for all
// the RPCs.
func WithUserAgent(s string) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.UserAgent = s
})
}
// WithKeepaliveParams returns a DialOption that specifies keepalive parameters
// for the client transport.
func WithKeepaliveParams(kp keepalive.ClientParameters) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.KeepaliveParams = kp
})
}
// WithUnaryInterceptor returns a DialOption that specifies the interceptor for
// unary RPCs.
func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.unaryInt = f
})
}
// WithStreamInterceptor returns a DialOption that specifies the interceptor for
// streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.streamInt = f
})
}
// WithAuthority returns a DialOption that specifies the value to be used as the
// :authority pseudo-header. This value only works with WithInsecure and has no
// effect if TransportCredentials are present.
func WithAuthority(a string) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.authority = a
})
}
// WithChannelzParentID returns a DialOption that specifies the channelz ID of
// current ClientConn's parent. This function is used in nested channel creation
// (e.g. grpclb dial).
func WithChannelzParentID(id int64) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.channelzParentID = id
})
}
// WithDisableServiceConfig returns a DialOption that causes grpc to ignore any
// service config provided by the resolver and provides a hint to the resolver
// to not fetch service configs.
func WithDisableServiceConfig() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.disableServiceConfig = true
})
}
// WithDisableRetry returns a DialOption that disables retries, even if the
// service config enables them. This does not impact transparent retries, which
// will happen automatically if no data is written to the wire or if the RPC is
// unprocessed by the remote server.
//
// Retry support is currently disabled by default, but will be enabled by
// default in the future. Until then, it may be enabled by setting the
// environment variable "GRPC_GO_RETRY" to "on".
//
// This API is EXPERIMENTAL.
func WithDisableRetry() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.disableRetry = true
})
}
// WithMaxHeaderListSize returns a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept.
func WithMaxHeaderListSize(s uint32) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.MaxHeaderListSize = &s
})
}
func defaultDialOptions() dialOptions {
return dialOptions{
disableRetry: !envconfig.Retry,
copts: transport.ConnectOptions{
WriteBufferSize: defaultWriteBufSize,
ReadBufferSize: defaultReadBufSize,
},
}
}

View File

@@ -82,7 +82,7 @@ type Codec interface {
Name() string
}
var registeredCodecs = make(map[string]Codec, 0)
var registeredCodecs = make(map[string]Codec)
// RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers.

View File

@@ -23,6 +23,7 @@ package gzip
import (
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"sync"
@@ -46,6 +47,26 @@ type writer struct {
pool *sync.Pool
}
// SetLevel updates the registered gzip compressor to use the compression level specified (gzip.HuffmanOnly is not supported).
// NOTE: this function must only be called during initialization time (i.e. in an init() function),
// and is not thread-safe.
//
// The error returned will be nil if the specified level is valid.
func SetLevel(level int) error {
if level < gzip.DefaultCompression || level > gzip.BestCompression {
return fmt.Errorf("grpc: invalid gzip compression level: %d", level)
}
c := encoding.GetCompressor(Name).(*compressor)
c.poolCompressor.New = func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, level)
if err != nil {
panic(err)
}
return &writer{Writer: w, pool: &c.poolCompressor}
}
return nil
}
func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
z := c.poolCompressor.Get().(*writer)
z.Writer.Reset(w)

View File

@@ -40,7 +40,7 @@ func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte
t.Errorf("codec.Unmarshal(_) returned an error")
}
if bytes.Compare(p.GetBody(), expectedBody) != 0 {
if !bytes.Equal(p.GetBody(), expectedBody) {
t.Errorf("Unexpected body; got %v; want %v", p.GetBody(), expectedBody)
}
}

View File

@@ -313,7 +313,7 @@ Now let's look at how we call our service methods. Note that in gRPC-Go, RPCs op
Calling the simple RPC `GetFeature` is nearly as straightforward as calling a local method.
```go
feature, err := client.GetFeature(context.Background(), &pb.Point{409146138, -746188906})
feature, err := client.GetFeature(ctx, &pb.Point{409146138, -746188906})
if err != nil {
...
}
@@ -331,7 +331,7 @@ Here's where we call the server-side streaming method `ListFeatures`, which retu
```go
rect := &pb.Rectangle{ ... } // initialize a pb.Rectangle
stream, err := client.ListFeatures(context.Background(), rect)
stream, err := client.ListFeatures(ctx, rect)
if err != nil {
...
}
@@ -364,7 +364,7 @@ for i := 0; i < pointCount; i++ {
points = append(points, randomPoint(r))
}
log.Printf("Traversing %d points.", len(points))
stream, err := client.RecordRoute(context.Background())
stream, err := client.RecordRoute(ctx)
if err != nil {
log.Fatalf("%v.RecordRoute(_) = _, %v", client, err)
}
@@ -387,7 +387,7 @@ The `RouteGuide_RecordRouteClient` has a `Send()` method that we can use to send
Finally, let's look at our bidirectional streaming RPC `RouteChat()`. As in the case of `RecordRoute`, we only pass the method a context object and get back a stream that we can use to both write and read messages. However, this time we return values via our method's stream while the server is still writing messages to *their* message stream.
```go
stream, err := client.RouteChat(context.Background())
stream, err := client.RouteChat(ctx)
waitc := make(chan struct{})
go func() {
for {

View File

@@ -21,6 +21,7 @@ package main
import (
"log"
"os"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
@@ -46,7 +47,9 @@ func main() {
if len(os.Args) > 1 {
name = os.Args[1]
}
r, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: name})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
r, err := c.SayHello(ctx, &pb.HelloRequest{Name: name})
if err != nil {
log.Fatalf("could not greet: %v", err)
}

Some files were not shown because too many files have changed in this diff Show More