Merge pull request #2101 from akvorado/fix/outlet-shutdown
Some checks failed
CI / 🤖 Check dependabot status (push) Has been cancelled
CI / 🐧 Test on Linux (${{ github.ref_type == 'tag' }}, misc) (push) Has been cancelled
CI / 🐧 Test on Linux (coverage) (push) Has been cancelled
CI / 🐧 Test on Linux (regular) (push) Has been cancelled
CI / ❄️ Build on Nix (push) Has been cancelled
CI / 🍏 Build and test on macOS (push) Has been cancelled
CI / 🧪 End-to-end testing (push) Has been cancelled
CI / 🔍 Upload code coverage (push) Has been cancelled
CI / 🔬 Test only Go (push) Has been cancelled
CI / 🔬 Test only JS (${{ needs.dependabot.outputs.package-ecosystem }}, 20) (push) Has been cancelled
CI / 🔬 Test only JS (${{ needs.dependabot.outputs.package-ecosystem }}, 22) (push) Has been cancelled
CI / 🔬 Test only JS (${{ needs.dependabot.outputs.package-ecosystem }}, 24) (push) Has been cancelled
CI / ⚖️ Check licenses (push) Has been cancelled
CI / 🐋 Build Docker images (push) Has been cancelled
CI / 🐋 Tag Docker images (push) Has been cancelled
CI / 🚀 Publish release (push) Has been cancelled
Update Nix dependency hashes / Update dependency hashes (push) Has been cancelled

outlet/kafka: prevent discarding flows on shutdown
This commit is contained in:
Vincent Bernat
2025-11-18 21:25:39 +01:00
committed by GitHub
29 changed files with 898 additions and 191 deletions

View File

@@ -36,6 +36,7 @@ type OutletConfiguration struct {
Kafka kafka.Configuration
ClickHouseDB clickhousedb.Configuration
ClickHouse clickhouse.Configuration
Flow flow.Configuration
Core core.Configuration
Schema schema.Configuration
}
@@ -50,6 +51,7 @@ func (c *OutletConfiguration) Reset() {
Kafka: kafka.DefaultConfiguration(),
ClickHouseDB: clickhousedb.DefaultConfiguration(),
ClickHouse: clickhouse.DefaultConfiguration(),
Flow: flow.DefaultConfiguration(),
Core: core.DefaultConfiguration(),
Schema: schema.DefaultConfiguration(),
}
@@ -111,7 +113,7 @@ func outletStart(r *reporter.Reporter, config OutletConfiguration, checkOnly boo
if err != nil {
return fmt.Errorf("unable to initialize schema component: %w", err)
}
flowComponent, err := flow.New(r, flow.Dependencies{
flowComponent, err := flow.New(r, config.Flow, flow.Dependencies{
Schema: schemaComponent,
})
if err != nil {

View File

@@ -9,29 +9,18 @@ import (
"fmt"
"os"
"path/filepath"
"github.com/google/renameio/v2"
)
// Save persists the cache to the specified file
func (c *Cache[K, V]) Save(cacheFile string) error {
tmpFile, err := os.CreateTemp(
filepath.Dir(cacheFile),
fmt.Sprintf("%s-*", filepath.Base(cacheFile)))
if err != nil {
return fmt.Errorf("unable to create cache file %q: %w", cacheFile, err)
}
defer func() {
tmpFile.Close() // ignore errors
os.Remove(tmpFile.Name()) // ignore errors
}()
// Write cache
encoder := gob.NewEncoder(tmpFile)
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
if err := encoder.Encode(c); err != nil {
return fmt.Errorf("unable to encode cache: %w", err)
}
// Move cache to new location
if err := os.Rename(tmpFile.Name(), cacheFile); err != nil {
if err := renameio.WriteFile(cacheFile, buf.Bytes(), 0o666, renameio.WithTempDir(filepath.Dir(cacheFile))); err != nil {
return fmt.Errorf("unable to write cache file %q: %w", cacheFile, err)
}
return nil
@@ -53,7 +42,7 @@ func (c *Cache[K, V]) Load(cacheFile string) error {
// currentVersionNumber should be increased each time we change the way we
// encode the cache.
var currentVersionNumber = 11
const currentVersionNumber = 11
// GobEncode encodes the cache
func (c *Cache[K, V]) GobEncode() ([]byte, error) {
@@ -61,7 +50,8 @@ func (c *Cache[K, V]) GobEncode() ([]byte, error) {
encoder := gob.NewEncoder(&buf)
// Encode version
if err := encoder.Encode(&currentVersionNumber); err != nil {
version := currentVersionNumber
if err := encoder.Encode(&version); err != nil {
return nil, err
}
// Encode a representation of K and V. Gob decoding is pretty forgiving, we

View File

@@ -46,15 +46,15 @@ func (l *Logger) Log(level kgo.LogLevel, msg string, keyvals ...any) {
}
// Logf logs a message at the specified level for kfake.
func (l *Logger) Logf(level kfake.LogLevel, msg string, keyvals ...any) {
func (l *Logger) Logf(level kfake.LogLevel, msg string, args ...any) {
switch level {
case kfake.LogLevelError:
l.r.Error().Fields(keyvals).Msg(msg)
l.r.Error().Msgf(msg, args...)
case kfake.LogLevelWarn:
l.r.Warn().Fields(keyvals).Msg(msg)
l.r.Warn().Msgf(msg, args...)
case kfake.LogLevelInfo:
l.r.Info().Fields(keyvals).Msg(msg)
l.r.Info().Msgf(msg, args...)
case kfake.LogLevelDebug:
l.r.Debug().Fields(keyvals).Msg(msg)
l.r.Debug().Msgf(msg, args...)
}
}

View File

@@ -19,6 +19,7 @@ var Version = 5
var decoderMap = bimap.New(map[RawFlow_Decoder]string{
RawFlow_DECODER_NETFLOW: "netflow",
RawFlow_DECODER_SFLOW: "sflow",
RawFlow_DECODER_GOB: "gob",
})
// MarshalText turns a decoder to text
@@ -27,7 +28,7 @@ func (d RawFlow_Decoder) MarshalText() ([]byte, error) {
if ok {
return []byte(got), nil
}
return nil, errors.New("unknown decoder")
return nil, fmt.Errorf("unknown decoder %d", d)
}
// UnmarshalText provides a decoder from text

View File

@@ -608,7 +608,7 @@ exporter-classifiers:
### ClickHouse
The ClickHouse component pushes data to ClickHouse. There are two settings that
The ClickHouse component pushes data to ClickHouse. There are three settings that
are configurable:
- `maximum-batch-size` defines how many flows to send to ClickHouse in a single batch at most
@@ -621,6 +621,14 @@ send a batch of size at most `maximum-batch-size` at least every
The default value is 100 000 and allows ClickHouse to handle incoming flows
efficiently.
### Flow
The flow component decodes flows received from Kafka. There is only one setting:
- `state-persist-file` defines the location of the file to save the state of the
flow decoders and read it back on startup. It is used to store IPFIX/NetFlow
templates and options.
## Orchestrator service
The three main components of the orchestrator service are `schema`,

View File

@@ -15,11 +15,14 @@ identified with a specific icon:
- 💥 *config*: `skip-verify` is false by default in TLS configurations for
ClickHouse, Kafka and remote data sources (previously, `verify` was set to
false by default)
- 🩹 *inlet*: keep flows from one exporter into a single partition
- 🩹 *outlet*: provide additional grace time for a worker to send to ClickHouse
- 🩹 *outlet*: prevent discarding flows on shutdown
- 🩹 *outlet*: enhance scaling up and down workers to avoid hysteresis
- 🩹 *outlet*: accept flows where interface names or descriptions are missing
- 🩹 *docker*: update Traefik to 3.6.1 (for compatibility with Docker Engine 29)
- 🌱 *common*: enable block and mutex profiling
- 🌱 *outlet*: save IPFIX decoder state to a file to prevent discarding flows on start
- 🌱 *config*: rename `verify` to `skip-verify` in TLS configurations for
ClickHouse, Kafka and remote data sources (with inverted logic)
- 🌱 *config*: remote data sources accept a specific TLS configuration

View File

@@ -177,6 +177,7 @@ services:
ports:
- 10179:10179/tcp
restart: unless-stopped
stop_grace_period: 30s
depends_on:
akvorado-orchestrator:
condition: service_healthy
@@ -189,6 +190,7 @@ services:
- akvorado-run:/run/akvorado
environment:
AKVORADO_CFG_OUTLET_METADATA_CACHEPERSISTFILE: /run/akvorado/metadata.cache
AKVORADO_CFG_OUTLET_FLOW_STATEPERSISTFILE: /run/akvorado/flow.state
labels:
- traefik.enable=true
# Disable access logging of /api/v0/outlet/metrics

13
go.mod
View File

@@ -83,8 +83,11 @@ require (
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/cosiner/argv v0.1.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect
github.com/creack/pty v1.1.24 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/derekparker/trie v0.0.0-20230829180723-39f4de51ef7d // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
@@ -101,6 +104,8 @@ require (
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-delve/delve v1.25.2 // indirect
github.com/go-delve/liner v1.2.3-0.20231231155935-4726ab1d7f62 // indirect
github.com/go-faster/city v1.0.1 // indirect
github.com/go-faster/errors v0.7.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
@@ -112,11 +117,14 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-dap v0.12.0 // indirect
github.com/google/licensecheck v0.3.1 // indirect
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
github.com/google/renameio/v2 v2.0.0 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect
github.com/hashicorp/golang-lru v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/itchyny/timefmt-go v0.1.6 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -135,6 +143,7 @@ require (
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/mgechev/dots v1.0.0 // indirect
@@ -160,7 +169,9 @@ require (
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.54.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/segmentio/asm v1.2.1 // indirect
github.com/shirou/gopsutil/v3 v3.23.12 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
@@ -181,6 +192,7 @@ require (
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.starlark.net v0.0.0-20231101134539-556fd59b42f6 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
@@ -207,6 +219,7 @@ require (
tool (
github.com/dmarkham/enumer
github.com/frapposelli/wwhrd
github.com/go-delve/delve/cmd/dlv
github.com/mgechev/revive
github.com/mna/pigeon
github.com/planetscale/vtprotobuf/cmd/protoc-gen-go-vtproto

26
go.sum
View File

@@ -62,7 +62,10 @@ github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cosiner/argv v0.1.0 h1:BVDiEL32lwHukgJKP87btEPenzrrHUjajs/8yzaqcXg=
github.com/cosiner/argv v0.1.0/go.mod h1:EusR6TucWKX+zFgtdUsKT2Cvg45K5rtpCcWz4hK06d8=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.6 h1:XJtiaUW6dEEqVuZiMTn1ldk455QWwEIsMIJlo5vtkx0=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
@@ -71,6 +74,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/derekparker/trie v0.0.0-20230829180723-39f4de51ef7d h1:hUWoLdw5kvo2xCsqlsIBMvWUc1QCSsCYD2J2+Fg6YoU=
github.com/derekparker/trie v0.0.0-20230829180723-39f4de51ef7d/go.mod h1:C7Es+DLenIpPc9J6IYw4jrK0h7S9bKj4DNl8+KxGEXU=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -121,6 +126,10 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-delve/delve v1.25.2 h1:EI6EIWGKUEC7OVE5nfG2eQSv5xEgCRxO1+REB7FKCtE=
github.com/go-delve/delve v1.25.2/go.mod h1:sBjdpmDVpQd8nIMFldtqJZkk0RpGXrf8AAp5HeRi0CM=
github.com/go-delve/liner v1.2.3-0.20231231155935-4726ab1d7f62 h1:IGtvsNyIuRjl04XAOFGACozgUD7A82UffYxZt4DWbvA=
github.com/go-delve/liner v1.2.3-0.20231231155935-4726ab1d7f62/go.mod h1:biJCRbqp51wS+I92HMqn5H8/A0PAhxn2vyOT+JqhiGI=
github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw=
github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw=
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
@@ -185,6 +194,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-dap v0.12.0 h1:rVcjv3SyMIrpaOoTAdFDyHs99CwVOItIJGKLQFQhNeM=
github.com/google/go-dap v0.12.0/go.mod h1:tNjCASCm5cqePi/RVXXWEVqtnNLV1KTWtYOqu6rZNzc=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -193,6 +204,8 @@ github.com/google/licensecheck v0.3.1/go.mod h1:ORkR35t/JjW+emNKtfJDII0zlciG9Jgb
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg=
github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -206,6 +219,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 h1:pRhl55Yx1eC7BZ1N+BBWwn
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0/go.mod h1:XKMd7iuf/RGPSMJ/U4HP0zS2Z9Fh8Ps9a+6X26m/tmI=
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c=
github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
@@ -270,6 +285,9 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
@@ -360,13 +378,18 @@ github.com/quic-go/quic-go v0.54.1/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQ
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/russross/blackfriday v1.6.0 h1:KqfZb0pUVN2lYqZUYRddxF4OR8ZMURnJIG5Y3VRLtww=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/scrapli/scrapligo v1.3.3 h1:D9zj1QrOYNYAQ30YT7wfQBINvPGxvs5L5Lz+2LnL7V4=
github.com/scrapli/scrapligo v1.3.3/go.mod h1:pOWxVyPsQRrWTrkoSSDg05tjOqtWfLffAZtAsCc0w3M=
@@ -473,6 +496,8 @@ go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.starlark.net v0.0.0-20231101134539-556fd59b42f6 h1:+eC0F/k4aBLC4szgOcjd7bDTEnpxADJyWJE0yowgM3E=
go.starlark.net v0.0.0-20231101134539-556fd59b42f6/go.mod h1:LcLNIzVOMp4oV+uusnpk+VU+SzXaJakUuBjoCSWH5dM=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
@@ -557,6 +582,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -6,9 +6,7 @@ package kafka
import (
"context"
"encoding/binary"
"fmt"
"math/rand/v2"
"strings"
"time"
@@ -109,11 +107,9 @@ func (c *Component) Stop() error {
// Send a message to Kafka.
func (c *Component) Send(exporter string, payload []byte, finalizer func()) {
key := make([]byte, 4)
binary.BigEndian.PutUint32(key, rand.Uint32())
record := &kgo.Record{
Topic: c.kafkaTopic,
Key: key,
Key: []byte(exporter),
Value: payload,
}
c.kafkaClient.Produce(context.Background(), record, func(r *kgo.Record, err error) {

View File

@@ -626,7 +626,7 @@ ClassifyProviderRegex(Interface.Description, "^Transit: ([^ ]+)", "$1")`,
daemonComponent := daemon.NewMock(t)
metadataComponent := metadata.NewMock(t, r, metadata.DefaultConfiguration(),
metadata.Dependencies{Daemon: daemonComponent})
flowComponent, err := flow.New(r, flow.Dependencies{Schema: schema.NewMock(t)})
flowComponent, err := flow.New(r, flow.DefaultConfiguration(), flow.Dependencies{Schema: schema.NewMock(t)})
if err != nil {
t.Fatalf("flow.New() error:\n%+v", err)
}

View File

@@ -101,6 +101,7 @@ func (c *Component) Stop() error {
c.r.Info().Msg("core component stopped")
}()
c.r.Info().Msg("stopping core component")
c.d.Kafka.StopWorkers()
c.t.Kill(nil)
return c.t.Wait()
}

View File

@@ -40,7 +40,7 @@ func TestCore(t *testing.T) {
daemonComponent := daemon.NewMock(t)
metadataComponent := metadata.NewMock(t, r, metadata.DefaultConfiguration(),
metadata.Dependencies{Daemon: daemonComponent})
flowComponent, err := flow.New(r, flow.Dependencies{Schema: schema.NewMock(t)})
flowComponent, err := flow.New(r, flow.DefaultConfiguration(), flow.Dependencies{Schema: schema.NewMock(t)})
if err != nil {
t.Fatalf("flow.New() error:\n%+v", err)
}

View File

@@ -53,11 +53,6 @@ func (w *worker) shutdown() {
// processIncomingFlow processes one incoming flow from Kafka.
func (w *worker) processIncomingFlow(ctx context.Context, data []byte) error {
// Do nothing if we are shutting down
if !w.c.t.Alive() {
return kafka.ErrStopProcessing
}
// Raw flow decoding: fatal
w.c.metrics.rawFlowsReceived.Inc()
w.rawFlow.ResetVT()

16
outlet/flow/config.go Normal file
View File

@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: 2025 Free Mobile
// SPDX-License-Identifier: AGPL-3.0-only
package flow
// Configuration describes the configuration for the flow component.
type Configuration struct {
	// StatePersistFile defines a file to store decoder state (templates, sampling
	// rates) to survive restarts. The validate tag accepts either the empty
	// default or a valid file path; presumably an empty value disables
	// persistence — confirm against the component using this setting.
	StatePersistFile string `validate:"isdefault|filepath"`
}
// DefaultConfiguration returns the default configuration for the flow component.
// All fields keep their zero values.
func DefaultConfiguration() Configuration {
	var cfg Configuration
	return cfg
}

View File

@@ -65,7 +65,7 @@ func (nd *Decoder) decodeNFv5(packet *netflowlegacy.PacketNetFlowV5, ts, sysUpti
}
}
func (nd *Decoder) decodeNFv9IPFIX(version uint16, obsDomainID uint32, flowSets []any, samplingRateSys *samplingRateSystem, ts, sysUptime uint64, options decoder.Option, bf *schema.FlowMessage, finalize decoder.FinalizeFlowFunc) {
func (nd *Decoder) decodeNFv9IPFIX(version uint16, obsDomainID uint32, flowSets []any, tao *templatesAndOptions, ts, sysUptime uint64, options decoder.Option, bf *schema.FlowMessage, finalize decoder.FinalizeFlowFunc) {
// Look for sampling rate in option data flowsets
for _, flowSet := range flowSets {
switch tFlowSet := flowSet.(type) {
@@ -96,18 +96,18 @@ func (nd *Decoder) decodeNFv9IPFIX(version uint16, obsDomainID uint32, flowSets
samplingRate = (packetInterval + packetSpace) / packetInterval
}
if samplingRate > 0 {
samplingRateSys.SetSamplingRate(version, obsDomainID, samplerID, samplingRate)
tao.SetSamplingRate(version, obsDomainID, samplerID, samplingRate)
}
}
case netflow.DataFlowSet:
for _, record := range tFlowSet.Records {
nd.decodeRecord(version, obsDomainID, samplingRateSys, record.Values, ts, sysUptime, options, bf, finalize)
nd.decodeRecord(version, obsDomainID, tao, record.Values, ts, sysUptime, options, bf, finalize)
}
}
}
}
func (nd *Decoder) decodeRecord(version uint16, obsDomainID uint32, samplingRateSys *samplingRateSystem, fields []netflow.DataField, ts, sysUptime uint64, options decoder.Option, bf *schema.FlowMessage, finalize decoder.FinalizeFlowFunc) {
func (nd *Decoder) decodeRecord(version uint16, obsDomainID uint32, tao *templatesAndOptions, fields []netflow.DataField, ts, sysUptime uint64, options decoder.Option, bf *schema.FlowMessage, finalize decoder.FinalizeFlowFunc) {
var reversePresent *bitset.BitSet
for _, dir := range []direction{directionForward, directionReverse} {
var etype, dstPort, srcPort uint16
@@ -154,7 +154,7 @@ func (nd *Decoder) decodeRecord(version uint16, obsDomainID uint32, samplingRate
case netflow.IPFIX_FIELD_samplingInterval, netflow.IPFIX_FIELD_samplerRandomInterval:
bf.SamplingRate = decodeUNumber(v)
case netflow.IPFIX_FIELD_samplerId, netflow.IPFIX_FIELD_selectorId:
bf.SamplingRate = uint64(samplingRateSys.GetSamplingRate(version, obsDomainID, decodeUNumber(v)))
bf.SamplingRate = uint64(tao.GetSamplingRate(version, obsDomainID, decodeUNumber(v)))
// L3
case netflow.IPFIX_FIELD_sourceIPv4Address:
@@ -346,7 +346,7 @@ func (nd *Decoder) decodeRecord(version uint16, obsDomainID uint32, samplingRate
bf.AppendArrayUInt32(schema.ColumnMPLSLabels, mplsLabels)
}
if bf.SamplingRate == 0 {
bf.SamplingRate = uint64(samplingRateSys.GetSamplingRate(version, obsDomainID, 0))
bf.SamplingRate = uint64(tao.GetSamplingRate(version, obsDomainID, 0))
}
if dir == directionForward && reversePresent == nil {
finalize()

View File

@@ -0,0 +1,126 @@
// SPDX-FileCopyrightText: 2025 Free Mobile
// SPDX-License-Identifier: AGPL-3.0-only
package netflow
import (
"encoding/json"
"fmt"
"reflect"
"github.com/netsampler/goflow2/v2/decoders/netflow"
)
// MarshalText implements encoding.TextMarshaler for templateKey. The key is
// rendered as "version-obsDomainID-templateID" so it can serve as a JSON map key.
func (tk templateKey) MarshalText() ([]byte, error) {
	key := fmt.Sprintf("%d-%d-%d", tk.version, tk.obsDomainID, tk.templateID)
	return []byte(key), nil
}
// UnmarshalText implements encoding.TextUnmarshaler for templateKey. It parses
// the "version-obsDomainID-templateID" form produced by MarshalText.
func (tk *templateKey) UnmarshalText(text []byte) error {
	if _, err := fmt.Sscanf(string(text), "%d-%d-%d", &tk.version, &tk.obsDomainID, &tk.templateID); err != nil {
		return fmt.Errorf("invalid template key %q: %w", string(text), err)
	}
	return nil
}
// MarshalText implements encoding.TextMarshaler for samplingRateKey. The key is
// rendered as "version-obsDomainID-samplerID" so it can serve as a JSON map key.
func (srk samplingRateKey) MarshalText() ([]byte, error) {
	key := fmt.Sprintf("%d-%d-%d", srk.version, srk.obsDomainID, srk.samplerID)
	return []byte(key), nil
}
// UnmarshalText implements encoding.TextUnmarshaler for samplingRateKey. It
// parses the "version-obsDomainID-samplerID" form produced by MarshalText.
func (srk *samplingRateKey) UnmarshalText(text []byte) error {
	if _, err := fmt.Sscanf(string(text), "%d-%d-%d", &srk.version, &srk.obsDomainID, &srk.samplerID); err != nil {
		return fmt.Errorf("invalid sampling rate key %q: %w", string(text), err)
	}
	return nil
}
// MarshalJSON encodes a set of NetFlow templates. Each template is wrapped
// together with a type tag so the matching UnmarshalJSON can select the right
// concrete goflow2 record type when decoding.
func (t *templates) MarshalJSON() ([]byte, error) {
	type typedTemplate struct {
		Type     string
		Template any
	}
	out := make(map[templateKey]typedTemplate, len(*t))
	for key, tmpl := range *t {
		// Map the concrete template type to its stable tag.
		var kind string
		switch tmpl.(type) {
		case netflow.TemplateRecord:
			kind = "data"
		case netflow.IPFIXOptionsTemplateRecord:
			kind = "ipfix-option"
		case netflow.NFv9OptionsTemplateRecord:
			kind = "nfv9-option"
		default:
			return nil, fmt.Errorf("unknown template type %q", reflect.TypeOf(tmpl).String())
		}
		out[key] = typedTemplate{Type: kind, Template: tmpl}
	}
	return json.Marshal(&out)
}
// UnmarshalJSON decodes a set of NetFlow templates. Each entry carries a type
// tag (written by MarshalJSON) that selects the concrete goflow2 record type
// to decode into.
func (t *templates) UnmarshalJSON(data []byte) error {
	type typedTemplate struct {
		Type     string
		Template json.RawMessage
	}
	var raw map[templateKey]typedTemplate
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	result := make(templates, len(raw))
	for key, entry := range raw {
		var decoded any
		var err error
		switch entry.Type {
		case "data":
			var record netflow.TemplateRecord
			err = json.Unmarshal(entry.Template, &record)
			decoded = record
		case "ipfix-option":
			var record netflow.IPFIXOptionsTemplateRecord
			err = json.Unmarshal(entry.Template, &record)
			decoded = record
		case "nfv9-option":
			var record netflow.NFv9OptionsTemplateRecord
			err = json.Unmarshal(entry.Template, &record)
			decoded = record
		default:
			return fmt.Errorf("unknown type %q", entry.Type)
		}
		if err != nil {
			return err
		}
		result[key] = decoded
	}
	*t = result
	return nil
}
// MarshalJSON encodes the NetFlow decoder's collection (templates and
// sampling rates) so its state can be serialized and restored with
// UnmarshalJSON.
func (nd *Decoder) MarshalJSON() ([]byte, error) {
	return json.Marshal(&nd.collection.Collection)
}
// UnmarshalJSON decodes the NetFlow decoder's collection.
func (nd *Decoder) UnmarshalJSON(data []byte) error {
	err := json.Unmarshal(data, &nd.collection.Collection)
	if err != nil {
		return err
	}
	// The back-reference to the decoder is not serialized; reattach it on
	// every restored entry.
	for _, entry := range nd.collection.Collection {
		entry.nd = nd
	}
	return nil
}

View File

@@ -0,0 +1,83 @@
// SPDX-FileCopyrightText: 2025 Free Mobile
// SPDX-License-Identifier: AGPL-3.0-only
package netflow
import (
"encoding/json"
"testing"
"akvorado/common/helpers"
"akvorado/common/reporter"
"akvorado/common/schema"
"akvorado/outlet/flow/decoder"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/netsampler/goflow2/v2/decoders/netflow"
)
// TestMarshalUnmarshalTemplates checks that a decoder's template and
// sampling-rate collection survives a JSON marshal/unmarshal round trip.
func TestMarshalUnmarshalTemplates(t *testing.T) {
	r := reporter.NewMock(t)
	sch := schema.NewMock(t)
	nfdecoder := New(r, decoder.Dependencies{Schema: sch})
	// Populate the collection for one exporter with sampling rates and one
	// template of each supported kind (data, IPFIX options, NFv9 options).
	collection := &nfdecoder.(*Decoder).collection
	exporter := collection.Get("::ffff:192.168.1.1")
	exporter.SetSamplingRate(10, 300, 10, 2048)
	exporter.SetSamplingRate(9, 301, 11, 4096)
	exporter.AddTemplate(10, 300, 300, netflow.TemplateRecord{
		TemplateId: 300,
		FieldCount: 2,
		Fields: []netflow.Field{
			{
				Type:   netflow.IPFIX_FIELD_applicationName,
				Length: 10,
			}, {
				Type:   netflow.IPFIX_FIELD_VRFname,
				Length: 25,
			},
		},
	})
	exporter.AddTemplate(10, 300, 301, netflow.IPFIXOptionsTemplateRecord{
		TemplateId:      301,
		FieldCount:      2,
		ScopeFieldCount: 0,
		Options: []netflow.Field{
			{
				Type:   netflow.IPFIX_FIELD_samplerRandomInterval,
				Length: 4,
			}, {
				Type:   netflow.IPFIX_FIELD_samplerMode,
				Length: 4,
			},
		},
	})
	exporter.AddTemplate(9, 301, 300, netflow.NFv9OptionsTemplateRecord{
		TemplateId:   300,
		OptionLength: 2,
		Options: []netflow.Field{
			{
				Type:   netflow.NFV9_FIELD_FLOW_ACTIVE_TIMEOUT,
				Length: 4,
			}, {
				Type:   netflow.NFV9_FIELD_FORWARDING_STATUS,
				Length: 2,
			},
		},
	})
	// Serialize the populated decoder, then restore into a fresh one.
	jsonBytes, err := json.Marshal(&nfdecoder)
	if err != nil {
		t.Fatalf("json.Marshal() error:\n%+v", err)
	}
	nfdecoder2 := New(r, decoder.Dependencies{Schema: sch})
	if err := json.Unmarshal(jsonBytes, &nfdecoder2); err != nil {
		t.Fatalf("json.Unmarshal() error:\n%+v", err)
	}
	// Compare collections, ignoring unexported fields (e.g. internal
	// back-references that are rebuilt rather than serialized).
	collection1 := &nfdecoder.(*Decoder).collection.Collection
	collection2 := &nfdecoder2.(*Decoder).collection.Collection
	if diff := helpers.Diff(collection1, collection2,
		cmpopts.IgnoreUnexported(templatesAndOptions{})); diff != "" {
		t.Fatalf("json.Marshal()/json.Unmarshal() (-got, +want):\n%s", diff)
	}
}

View File

@@ -9,8 +9,6 @@ import (
"encoding/binary"
"errors"
"fmt"
"strconv"
"sync"
"time"
"github.com/netsampler/goflow2/v2/decoders/netflow"
@@ -29,9 +27,7 @@ type Decoder struct {
errLogger reporter.Logger
// Templates and sampling systems
systemsLock sync.RWMutex
templates map[string]*templateSystem
sampling map[string]*samplingRateSystem
collection templateAndOptionCollection
metrics struct {
errors *reporter.CounterVec
@@ -48,8 +44,10 @@ func New(r *reporter.Reporter, dependencies decoder.Dependencies) decoder.Decode
r: r,
d: dependencies,
errLogger: r.Sample(reporter.BurstSampler(30*time.Second, 3)),
templates: map[string]*templateSystem{},
sampling: map[string]*samplingRateSystem{},
}
nd.collection = templateAndOptionCollection{
nd: nd,
Collection: make(map[string]*templatesAndOptions),
}
nd.metrics.errors = nd.r.CounterVec(
@@ -91,108 +89,13 @@ func New(r *reporter.Reporter, dependencies decoder.Dependencies) decoder.Decode
return nd
}
type templateSystem struct {
nd *Decoder
key string
templates netflow.NetFlowTemplateSystem
}
func (s *templateSystem) AddTemplate(version uint16, obsDomainID uint32, templateID uint16, template any) error {
if err := s.templates.AddTemplate(version, obsDomainID, templateID, template); err != nil {
return nil
}
var typeStr string
switch templateIDConv := template.(type) {
case netflow.IPFIXOptionsTemplateRecord:
templateID = templateIDConv.TemplateId
typeStr = "options_template"
case netflow.NFv9OptionsTemplateRecord:
templateID = templateIDConv.TemplateId
typeStr = "options_template"
case netflow.TemplateRecord:
templateID = templateIDConv.TemplateId
typeStr = "template"
}
s.nd.metrics.templates.WithLabelValues(
s.key,
strconv.Itoa(int(version)),
strconv.Itoa(int(obsDomainID)),
strconv.Itoa(int(templateID)),
typeStr,
).Inc()
return nil
}
func (s *templateSystem) GetTemplate(version uint16, obsDomainID uint32, templateID uint16) (any, error) {
return s.templates.GetTemplate(version, obsDomainID, templateID)
}
func (s *templateSystem) RemoveTemplate(version uint16, obsDomainID uint32, templateID uint16) (any, error) {
return s.templates.RemoveTemplate(version, obsDomainID, templateID)
}
type samplingRateKey struct {
version uint16
obsDomainID uint32
samplerID uint64
}
type samplingRateSystem struct {
lock sync.RWMutex
rates map[samplingRateKey]uint32
}
func (s *samplingRateSystem) GetSamplingRate(version uint16, obsDomainID uint32, samplerID uint64) uint32 {
s.lock.RLock()
defer s.lock.RUnlock()
rate := s.rates[samplingRateKey{
version: version,
obsDomainID: obsDomainID,
samplerID: samplerID,
}]
return rate
}
func (s *samplingRateSystem) SetSamplingRate(version uint16, obsDomainID uint32, samplerID uint64, samplingRate uint32) {
s.lock.Lock()
defer s.lock.Unlock()
s.rates[samplingRateKey{
version: version,
obsDomainID: obsDomainID,
samplerID: samplerID,
}] = samplingRate
}
// Decode decodes a NetFlow payload.
func (nd *Decoder) Decode(in decoder.RawFlow, options decoder.Option, bf *schema.FlowMessage, finalize decoder.FinalizeFlowFunc) (int, error) {
if len(in.Payload) < 2 {
return 0, errors.New("payload too small")
}
key := in.Source.String()
nd.systemsLock.RLock()
templates, tok := nd.templates[key]
sampling, sok := nd.sampling[key]
nd.systemsLock.RUnlock()
if !tok {
templates = &templateSystem{
nd: nd,
templates: netflow.CreateTemplateSystem(),
key: key,
}
nd.systemsLock.Lock()
nd.templates[key] = templates
nd.systemsLock.Unlock()
}
if !sok {
sampling = &samplingRateSystem{
rates: map[samplingRateKey]uint32{},
}
nd.systemsLock.Lock()
nd.sampling[key] = sampling
nd.systemsLock.Unlock()
}
tao := nd.collection.Get(key)
var (
sysUptime uint64
@@ -230,7 +133,7 @@ func (nd *Decoder) Decode(in decoder.RawFlow, options decoder.Option, bf *schema
nd.decodeNFv5(&packetNFv5, ts, sysUptime, options, bf, finalize2)
case 9:
var packetNFv9 netflow.NFv9Packet
if err := netflow.DecodeMessageNetFlow(buf, templates, &packetNFv9); err != nil {
if err := netflow.DecodeMessageNetFlow(buf, tao, &packetNFv9); err != nil {
if !errors.Is(err, netflow.ErrorTemplateNotFound) {
nd.errLogger.Err(err).Str("exporter", key).Msg("error while decoding NetFlow v9")
nd.metrics.errors.WithLabelValues(key, "NetFlow v9 decoding error").Inc()
@@ -246,10 +149,10 @@ func (nd *Decoder) Decode(in decoder.RawFlow, options decoder.Option, bf *schema
ts = uint64(packetNFv9.UnixSeconds)
sysUptime = uint64(packetNFv9.SystemUptime)
}
nd.decodeNFv9IPFIX(version, obsDomainID, flowSets, sampling, ts, sysUptime, options, bf, finalize2)
nd.decodeNFv9IPFIX(version, obsDomainID, flowSets, tao, ts, sysUptime, options, bf, finalize2)
case 10:
var packetIPFIX netflow.IPFIXPacket
if err := netflow.DecodeMessageIPFIX(buf, templates, &packetIPFIX); err != nil {
if err := netflow.DecodeMessageIPFIX(buf, tao, &packetIPFIX); err != nil {
if !errors.Is(err, netflow.ErrorTemplateNotFound) {
nd.errLogger.Err(err).Str("exporter", key).Msg("error while decoding IPFIX")
nd.metrics.errors.WithLabelValues(key, "IPFIX decoding error").Inc()
@@ -264,7 +167,7 @@ func (nd *Decoder) Decode(in decoder.RawFlow, options decoder.Option, bf *schema
if options.TimestampSource == pb.RawFlow_TS_NETFLOW_PACKET {
ts = uint64(packetIPFIX.ExportTime)
}
nd.decodeNFv9IPFIX(version, obsDomainID, flowSets, sampling, ts, sysUptime, options, bf, finalize2)
nd.decodeNFv9IPFIX(version, obsDomainID, flowSets, tao, ts, sysUptime, options, bf, finalize2)
default:
nd.errLogger.Warn().Str("exporter", key).Msg("unknown NetFlow version")
nd.metrics.packets.WithLabelValues(key, "unknown").

View File

@@ -0,0 +1,141 @@
// SPDX-FileCopyrightText: 2025 Free Mobile
// SPDX-License-Identifier: AGPL-3.0-only
package netflow
import (
"strconv"
"sync"
"github.com/netsampler/goflow2/v2/decoders/netflow"
)
// templateAndOptionCollection map exporters to the set of templates and options we
// received from them.
type templateAndOptionCollection struct {
nd *Decoder
lock sync.Mutex
Collection map[string]*templatesAndOptions
}
// templatesAndOptions contains templates and options associated to an exporter.
type templatesAndOptions struct {
nd *Decoder
templateLock sync.RWMutex
samplingRateLock sync.RWMutex
Key string
Templates templates
SamplingRates map[samplingRateKey]uint32
}
// templates is a mapping to one of netflow.TemplateRecord,
// netflow.IPFIXOptionsTemplateRecord, netflow.NFv9OptionsTemplateRecord.
type templates map[templateKey]any
// templateKey is the key structure to access a template.
type templateKey struct {
version uint16
obsDomainID uint32
templateID uint16
}
// samplingRateKey is the key structure to access a sampling rate.
type samplingRateKey struct {
version uint16
obsDomainID uint32
samplerID uint64
}
var (
_ netflow.NetFlowTemplateSystem = &templatesAndOptions{}
)
// Get returns the templates and options associated with key, creating and
// registering an empty entry on first use. It is safe for concurrent use.
func (c *templateAndOptionCollection) Get(key string) *templatesAndOptions {
	c.lock.Lock()
	defer c.lock.Unlock()
	if existing, found := c.Collection[key]; found {
		return existing
	}
	entry := &templatesAndOptions{
		nd:            c.nd,
		Key:           key,
		Templates:     make(map[templateKey]any),
		SamplingRates: make(map[samplingRateKey]uint32),
	}
	c.Collection[key] = entry
	return entry
}
// RemoveTemplate removes an existing template. This is a noop as it is not
// really needed: AddTemplate simply overwrites an entry with the same key,
// and keeping a stale template around is harmless. The method only exists to
// satisfy netflow.NetFlowTemplateSystem.
func (t *templatesAndOptions) RemoveTemplate(uint16, uint32, uint16) (any, error) {
return nil, nil
}
// GetTemplate looks up the template identified by (version, obsDomainID,
// templateID). It returns netflow.ErrorTemplateNotFound when no such template
// was recorded for this exporter.
func (t *templatesAndOptions) GetTemplate(version uint16, obsDomainID uint32, templateID uint16) (any, error) {
	k := templateKey{version: version, obsDomainID: obsDomainID, templateID: templateID}
	t.templateLock.RLock()
	defer t.templateLock.RUnlock()
	if tpl, found := t.Templates[k]; found {
		return tpl, nil
	}
	return nil, netflow.ErrorTemplateNotFound
}
// AddTemplate stores a template and increments the per-exporter template
// metric. template is expected to be one of netflow.TemplateRecord,
// netflow.IPFIXOptionsTemplateRecord or netflow.NFv9OptionsTemplateRecord;
// the metric label and the authoritative template ID are derived from the
// concrete record type.
func (t *templatesAndOptions) AddTemplate(version uint16, obsDomainID uint32, templateID uint16, template any) error {
	var kind string
	switch record := template.(type) {
	case netflow.IPFIXOptionsTemplateRecord:
		kind = "options_template"
		templateID = record.TemplateId
	case netflow.NFv9OptionsTemplateRecord:
		kind = "options_template"
		templateID = record.TemplateId
	case netflow.TemplateRecord:
		kind = "template"
		templateID = record.TemplateId
	}
	t.nd.metrics.templates.WithLabelValues(
		t.Key,
		strconv.Itoa(int(version)),
		strconv.Itoa(int(obsDomainID)),
		strconv.Itoa(int(templateID)),
		kind,
	).Inc()
	t.templateLock.Lock()
	t.Templates[templateKey{version: version, obsDomainID: obsDomainID, templateID: templateID}] = template
	t.templateLock.Unlock()
	return nil
}
// GetSamplingRate returns the sampling rate recorded for the given sampler,
// or 0 when none was recorded yet (zero value of the map lookup).
func (t *templatesAndOptions) GetSamplingRate(version uint16, obsDomainID uint32, samplerID uint64) uint32 {
	k := samplingRateKey{
		version:     version,
		obsDomainID: obsDomainID,
		samplerID:   samplerID,
	}
	t.samplingRateLock.RLock()
	defer t.samplingRateLock.RUnlock()
	return t.SamplingRates[k]
}
// SetSamplingRate records the sampling rate for the given sampler,
// overwriting any previous value.
func (t *templatesAndOptions) SetSamplingRate(version uint16, obsDomainID uint32, samplerID uint64, samplingRate uint32) {
	k := samplingRateKey{
		version:     version,
		obsDomainID: obsDomainID,
		samplerID:   samplerID,
	}
	t.samplingRateLock.Lock()
	t.SamplingRates[k] = samplingRate
	t.samplingRateLock.Unlock()
}

View File

@@ -24,10 +24,12 @@ import (
func TestFlowDecode(t *testing.T) {
r := reporter.NewMock(t)
sch := schema.NewMock(t)
c, err := New(r, Dependencies{Schema: sch})
c, err := New(r, DefaultConfiguration(), Dependencies{Schema: sch})
if err != nil {
t.Fatalf("New() error:\n%+v", err)
}
helpers.StartStop(t, c)
bf := sch.NewFlowMessage()
got := []*schema.FlowMessage{}
finalize := func() {

80
outlet/flow/persist.go Normal file
View File

@@ -0,0 +1,80 @@
// SPDX-FileCopyrightText: 2025 Free Mobile
// SPDX-License-Identifier: AGPL-3.0-only
package flow
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"akvorado/common/pb"
"github.com/google/renameio/v2"
)
// ErrStateVersion is triggered when loading a collection from an incompatible version
var ErrStateVersion = errors.New("collection version mismatch")
// currentStateVersionNumber should be increased each time we change the way we
// encode the collection.
const currentStateVersionNumber = 1
// SaveState saves the decoders' state to a file. This is not goroutine-safe.
// The state is serialized as JSON together with a version number so that an
// incompatible file can be rejected by RestoreState.
func (c *Component) SaveState(target string) error {
	payload := struct {
		Version  int
		Decoders any
	}{
		Version:  currentStateVersionNumber,
		Decoders: c.decoders,
	}
	encoded, err := json.Marshal(&payload)
	if err != nil {
		return fmt.Errorf("unable to encode decoders' state: %w", err)
	}
	// Atomic replacement: the temporary file lives in the target's own
	// directory so the final rename cannot cross filesystems.
	err = renameio.WriteFile(target, encoded, 0o666, renameio.WithTempDir(filepath.Dir(target)))
	if err != nil {
		return fmt.Errorf("unable to write state file %q: %w", target, err)
	}
	return nil
}
// RestoreState restores the decoders' state from a file. This is not
// goroutine-safe. A file produced by an incompatible encoding version is
// rejected with ErrStateVersion. Decoders present in the file but unknown to
// the component (and vice versa) are silently skipped.
func (c *Component) RestoreState(source string) error {
	data, err := os.ReadFile(source)
	if err != nil {
		return fmt.Errorf("unable to read state file %q: %w", source, err)
	}
	// Check the version first: the rest of the payload is only meaningful
	// when it matches currentStateVersionNumber.
	var stateVersion struct {
		Version int
	}
	if err := json.Unmarshal(data, &stateVersion); err != nil {
		// Wrap the raw JSON error so the caller knows which file was bad.
		return fmt.Errorf("unable to decode state file %q: %w", source, err)
	}
	if stateVersion.Version != currentStateVersionNumber {
		return ErrStateVersion
	}
	// Decode each decoder's state lazily: json.RawMessage defers parsing
	// until we know which decoder a payload belongs to.
	var stateDecoders struct {
		Decoders map[pb.RawFlow_Decoder]json.RawMessage
	}
	if err := json.Unmarshal(data, &stateDecoders); err != nil {
		return fmt.Errorf("unable to decode decoders' state: %w", err)
	}
	for k, v := range c.decoders {
		decoderJSON, ok := stateDecoders.Decoders[k]
		if !ok {
			continue
		}
		// v is an interface holding a pointer to the concrete decoder, so
		// json.Unmarshal fills the existing decoder in place.
		if err := json.Unmarshal(decoderJSON, &v); err != nil {
			return fmt.Errorf("unable to decode decoder's state (%s): %w", k, err)
		}
	}
	return nil
}

154
outlet/flow/persist_test.go Normal file
View File

@@ -0,0 +1,154 @@
// SPDX-FileCopyrightText: 2025 Free Mobile
// SPDX-License-Identifier: AGPL-3.0-only
package flow
import (
"errors"
"net"
"os"
"path"
"path/filepath"
"runtime"
"slices"
"testing"
"time"
"akvorado/common/helpers"
"akvorado/common/pb"
"akvorado/common/reporter"
"akvorado/common/schema"
)
// TestSaveAndRestore feeds NetFlow templates to one flow component, stops it
// (persisting its state), then checks that a fresh component restarted with
// the same StatePersistFile can decode data packets using the saved
// templates without having seen them on the wire.
func TestSaveAndRestore(t *testing.T) {
r := reporter.NewMock(t)
sch := schema.NewMock(t)
config := DefaultConfiguration()
// Persist state to a per-test temporary file shared by both components.
config.StatePersistFile = filepath.Join(t.TempDir(), "state")
c, err := New(r, config, Dependencies{Schema: sch})
if err != nil {
t.Fatalf("New() error:\n%+v", err)
}
if err := c.Start(); err != nil {
t.Fatalf("Start() error:\n%+v", err)
}
bf := sch.NewFlowMessage()
// Locate the NetFlow pcap fixtures relative to this source file.
_, src, _, _ := runtime.Caller(0)
base := path.Join(path.Dir(src), "decoder", "netflow", "testdata")
// Decode only template/options packets: this populates the decoder state
// that Stop() is expected to persist.
for _, pcap := range []string{"options-template.pcap", "options-data.pcap", "template.pcap"} {
data := helpers.ReadPcapL4(t, path.Join(base, pcap))
rawFlow := &pb.RawFlow{
TimeReceived: uint64(time.Now().UnixNano()),
Payload: data,
SourceAddress: net.ParseIP("127.0.0.1").To16(),
UseSourceAddress: false,
Decoder: pb.RawFlow_DECODER_NETFLOW,
TimestampSource: pb.RawFlow_TS_INPUT,
}
err := c.Decode(rawFlow, bf, func() {})
if err != nil {
t.Fatalf("Decode() error:\n%+v", err)
}
}
// Stop() saves the decoders' state to config.StatePersistFile.
if err := c.Stop(); err != nil {
t.Fatalf("Stop() error:\n%+v", err)
}
// Create a second component that will reuse saved templates.
r2 := reporter.NewMock(t)
c2, err := New(r2, config, Dependencies{Schema: sch})
if err != nil {
t.Fatalf("New() error:\n%+v", err)
}
if err := c2.Start(); err != nil {
t.Fatalf("Start() error:\n%+v", err)
}
got := []*schema.FlowMessage{}
// Decode only data packets: without the restored templates these could not
// be decoded into flows.
for _, pcap := range []string{"data.pcap"} {
data := helpers.ReadPcapL4(t, path.Join(base, pcap))
rawFlow := &pb.RawFlow{
TimeReceived: uint64(time.Now().UnixNano()),
Payload: data,
SourceAddress: net.ParseIP("127.0.0.1").To16(),
UseSourceAddress: false,
Decoder: pb.RawFlow_DECODER_NETFLOW,
TimestampSource: pb.RawFlow_TS_INPUT,
}
err := c2.Decode(rawFlow, bf, func() {
// Snapshot each finalized flow before the buffer is reset.
clone := *bf
got = append(got, &clone)
bf.Finalize()
})
if err != nil {
t.Fatalf("Decode() error:\n%+v", err)
}
}
// At least one flow proves the templates survived the restart.
if len(got) == 0 {
t.Fatalf("Decode() returned no flows")
}
}
// TestRestoreCorruptedFile checks that RestoreState reports an error when the
// state file does not contain valid JSON.
func TestRestoreCorruptedFile(t *testing.T) {
	// Write a state file containing invalid data.
	corruptedFile := filepath.Join(t.TempDir(), "corrupted.json")
	if err := os.WriteFile(corruptedFile, []byte("not valid JSON data"), 0644); err != nil {
		t.Fatalf("WriteFile() error:\n%+v", err)
	}
	r := reporter.NewMock(t)
	sch := schema.NewMock(t)
	config := DefaultConfiguration()
	config.StatePersistFile = corruptedFile
	c, err := New(r, config, Dependencies{Schema: sch})
	if err != nil {
		t.Fatalf("New() error:\n%+v", err)
	}
	if err := c.RestoreState(corruptedFile); err == nil {
		t.Error("Restore(): no error")
	}
}
// TestRestoreVersionMismatch checks that RestoreState rejects a state file
// written with an incompatible version number with ErrStateVersion, and that
// the component's decoders remain usable afterwards.
func TestRestoreVersionMismatch(t *testing.T) {
	// Create a file with a different version number
	tmpDir := t.TempDir()
	versionMismatchFile := filepath.Join(tmpDir, "version_mismatch.json")
	// Write a JSON file with version 999 (incompatible version). JSON field
	// matching is case-insensitive, so "version" maps to Version.
	incompatibleData := `{"version":999,"collection":{}}`
	err := os.WriteFile(versionMismatchFile, []byte(incompatibleData), 0o644)
	if err != nil {
		t.Fatalf("WriteFile() error:\n%+v", err)
	}
	r := reporter.NewMock(t)
	sch := schema.NewMock(t)
	config := DefaultConfiguration()
	config.StatePersistFile = versionMismatchFile
	c, err := New(r, config, Dependencies{Schema: sch})
	if err != nil {
		t.Fatalf("New() error:\n%+v", err)
	}
	err = c.RestoreState(versionMismatchFile)
	if err == nil {
		t.Fatal("Restore(): expected error for version mismatch, got nil")
	}
	// The failure message previously named a non-existent ErrVersion; the
	// sentinel is ErrStateVersion.
	if !errors.Is(err, ErrStateVersion) {
		t.Errorf("Restore(): expected ErrStateVersion, got %v", err)
	}
	// Also check we have c.decoders OK
	names := []string{}
	for _, d := range c.decoders {
		names = append(names, d.Name())
	}
	slices.Sort(names)
	if diff := helpers.Diff(names, []string{"gob", "netflow", "sflow"}); diff != "" {
		t.Fatalf("RestoreState(): invalid decoders:\n%s", diff)
	}
}

View File

@@ -9,7 +9,6 @@ import (
"akvorado/common/pb"
"akvorado/common/reporter"
"akvorado/common/schema"
"akvorado/outlet/flow/decoder"
)
@@ -17,6 +16,7 @@ import (
type Component struct {
r *reporter.Reporter
d *Dependencies
config Configuration
errLogger reporter.Logger
metrics struct {
@@ -29,22 +29,21 @@ type Component struct {
}
// Dependencies are the dependencies of the flow component.
type Dependencies struct {
Schema *schema.Component
}
type Dependencies = decoder.Dependencies
// New creates a new flow component.
func New(r *reporter.Reporter, dependencies Dependencies) (*Component, error) {
func New(r *reporter.Reporter, config Configuration, dependencies Dependencies) (*Component, error) {
c := Component{
r: r,
d: &dependencies,
config: config,
errLogger: r.Sample(reporter.BurstSampler(30*time.Second, 3)),
decoders: make(map[pb.RawFlow_Decoder]decoder.Decoder),
}
// Initialize available decoders
for decoderType, decoderFunc := range availableDecoders {
c.decoders[decoderType] = decoderFunc(r, decoder.Dependencies{Schema: c.d.Schema})
c.decoders[decoderType] = decoderFunc(r, dependencies)
}
// Metrics
@@ -65,3 +64,25 @@ func New(r *reporter.Reporter, dependencies Dependencies) (*Component, error) {
return &c, nil
}
// Start starts the flow component. When a state persistence file is
// configured, it tries to restore the decoders' state from it; a failure is
// logged as a warning and otherwise ignored so the component still starts.
func (c *Component) Start() error {
	if c.config.StatePersistFile == "" {
		return nil
	}
	if err := c.RestoreState(c.config.StatePersistFile); err != nil {
		c.r.Warn().Err(err).Msg("cannot load decoders' state, ignoring")
		return nil
	}
	c.r.Info().Msg("previous decoders' state loaded")
	return nil
}
// Stop stops the flow component. When a state persistence file is configured,
// the decoders' state is saved to it so templates and sampling rates survive
// a restart. A save failure is logged but does not fail the shutdown.
func (c *Component) Stop() error {
	if c.config.StatePersistFile != "" {
		if err := c.SaveState(c.config.StatePersistFile); err != nil {
			// Fixed typo in the log message ("decorders" -> "decoders").
			c.r.Err(err).Msg("cannot save decoders' state")
		}
	}
	return nil
}

View File

@@ -40,7 +40,7 @@ type ShutdownFunc func()
type WorkerBuilderFunc func(int, chan<- ScaleRequest) (ReceiveFunc, ShutdownFunc)
// NewConsumer creates a new consumer.
func (c *realComponent) NewConsumer(worker int, callback ReceiveFunc) *Consumer {
func (c *realComponent) newConsumer(worker int, callback ReceiveFunc) *Consumer {
return &Consumer{
r: c.r,
l: c.r.With().Int("worker", worker).Logger(),

View File

@@ -197,6 +197,125 @@ func TestStartSeveralWorkers(t *testing.T) {
}
}
// TestWorkerStop verifies that StopWorkers() synchronously stops a worker
// (its shutdown callback has run before StopWorkers returns), that the
// received-messages metric matches the last value processed, and that the
// consumer-group offsets were committed so a new consumer resumes exactly at
// last+1 — i.e. no flow is lost or reprocessed on shutdown.
func TestWorkerStop(t *testing.T) {
r := reporter.NewMock(t)
topicName := fmt.Sprintf("test-topic3-%d", rand.Int())
expectedTopicName := fmt.Sprintf("%s-v%d", topicName, pb.Version)
// In-process fake Kafka cluster with a single partition so one worker
// consumes everything in order.
cluster, err := kfake.NewCluster(
kfake.NumBrokers(1),
kfake.SeedTopics(1, expectedTopicName),
kfake.WithLogger(kafka.NewLogger(r)),
)
if err != nil {
t.Fatalf("NewCluster() error: %v", err)
}
defer cluster.Close()
// Start the component
configuration := DefaultConfiguration()
configuration.Topic = topicName
configuration.Brokers = cluster.ListenAddrs()
configuration.FetchMaxWaitTime = 100 * time.Millisecond
configuration.ConsumerGroup = fmt.Sprintf("outlet-%d", rand.Int())
configuration.MinWorkers = 1
c, err := New(r, configuration, Dependencies{Daemon: daemon.NewMock(t)})
if err != nil {
t.Fatalf("New() error:\n%+v", err)
}
helpers.StartStop(t, c)
// last records the most recent payload seen; done is closed by the
// worker's shutdown callback.
var last int
done := make(chan bool)
c.StartWorkers(func(int, chan<- ScaleRequest) (ReceiveFunc, ShutdownFunc) {
return func(_ context.Context, got []byte) error {
last, _ = strconv.Atoi(string(got))
return nil
}, func() {
close(done)
}
})
time.Sleep(50 * time.Millisecond)
// Start producing
producerConfiguration := kafka.DefaultConfiguration()
producerConfiguration.Brokers = cluster.ListenAddrs()
producerOpts, err := kafka.NewConfig(reporter.NewMock(t), producerConfiguration)
if err != nil {
t.Fatalf("NewConfig() error:\n%+v", err)
}
producerOpts = append(producerOpts, kgo.ProducerLinger(0))
producer, err := kgo.NewClient(producerOpts...)
if err != nil {
t.Fatalf("NewClient() error:\n%+v", err)
}
defer producer.Close()
produceCtx, cancel := context.WithCancel(t.Context())
defer cancel()
// Produce increasing integers until the test ends.
go func() {
for i := 1; ; i++ {
record := &kgo.Record{
Topic: expectedTopicName,
Value: []byte(strconv.Itoa(i)),
}
producer.ProduceSync(produceCtx, record)
time.Sleep(5 * time.Millisecond)
}
}()
// Wait a bit and stop workers
time.Sleep(500 * time.Millisecond)
c.StopWorkers()
// StopWorkers must wait for the worker: the non-blocking receive only
// succeeds if the shutdown callback already closed done.
select {
case <-done:
default:
t.Fatal("StopWorkers(): worker still running!")
}
// The worker should have counted exactly as many messages as the last
// payload value it processed.
gotMetrics := r.GetMetrics("akvorado_outlet_kafka_", "received_messages_total")
expected := map[string]string{
`received_messages_total{worker="0"}`: strconv.Itoa(last),
}
if diff := helpers.Diff(gotMetrics, expected); diff != "" {
t.Fatalf("Metrics (-got, +want):\n%s", diff)
}
// Check that if we consume from the same group, we will resume from last+1
consumerConfiguration := kafka.DefaultConfiguration()
consumerConfiguration.Brokers = cluster.ListenAddrs()
consumerOpts, err := kafka.NewConfig(reporter.NewMock(t), consumerConfiguration)
if err != nil {
t.Fatalf("NewConfig() error:\n%+v", err)
}
consumerOpts = append(consumerOpts,
kgo.ConsumerGroup(configuration.ConsumerGroup),
kgo.ConsumeTopics(expectedTopicName),
kgo.FetchMinBytes(1),
kgo.FetchMaxWait(10*time.Millisecond),
kgo.ConsumeStartOffset(kgo.NewOffset().AtStart()),
)
consumer, err := kgo.NewClient(consumerOpts...)
if err != nil {
t.Fatalf("NewClient() error:\n%+v", err)
}
defer consumer.Close()
fetches := consumer.PollFetches(t.Context())
if fetches.IsClientClosed() {
t.Fatal("PollFetches(): client is closed")
}
fetches.EachError(func(_ string, _ int32, err error) {
t.Fatalf("PollFetches() error:\n%+v", err)
})
// first is the first uncommitted payload the group would redeliver.
var first int
fetches.EachRecord(func(r *kgo.Record) {
if first == 0 {
first, _ = strconv.Atoi(string(r.Value))
}
})
// Offsets were committed up to last, so consumption resumes at last+1.
if last+1 != first {
t.Fatalf("PollFetches: %d -> %d", last, first)
}
}
func TestWorkerScaling(t *testing.T) {
r := reporter.NewMock(t)
topicName := fmt.Sprintf("test-topic2-%d", rand.Int())
@@ -204,7 +323,7 @@ func TestWorkerScaling(t *testing.T) {
cluster, err := kfake.NewCluster(
kfake.NumBrokers(1),
kfake.SeedTopics(16, expectedTopicName),
kfake.SeedTopics(4, expectedTopicName),
kfake.WithLogger(kafka.NewLogger(r)),
)
if err != nil {
@@ -241,8 +360,8 @@ func TestWorkerScaling(t *testing.T) {
}
helpers.StartStop(t, c)
if maxWorkers := c.(*realComponent).config.MaxWorkers; maxWorkers != 16 {
t.Errorf("Start() max workers should have been capped to 16 instead of %d", maxWorkers)
if maxWorkers := c.(*realComponent).config.MaxWorkers; maxWorkers != 4 {
t.Errorf("Start() max workers should have been capped to 4 instead of %d", maxWorkers)
}
msg := atomic.Uint32{}
c.StartWorkers(func(_ int, ch chan<- ScaleRequest) (ReceiveFunc, ShutdownFunc) {
@@ -267,7 +386,7 @@ func TestWorkerScaling(t *testing.T) {
"worker_increase_total": "1",
"workers": "1",
"min_workers": "1",
"max_workers": "16",
"max_workers": "4",
}
if diff := helpers.Diff(gotMetrics, expected); diff != "" {
t.Fatalf("Metrics (-got, +want):\n%s", diff)
@@ -281,18 +400,27 @@ func TestWorkerScaling(t *testing.T) {
if results := producer.ProduceSync(context.Background(), record); results.FirstErr() != nil {
t.Fatalf("ProduceSync() error:\n%+v", results.FirstErr())
}
time.Sleep(100 * time.Millisecond)
t.Log("Check if workers increased to 9")
gotMetrics = r.GetMetrics("akvorado_outlet_kafka_", "worker")
expected = map[string]string{
"worker_decrease_total": "0",
"worker_increase_total": "9",
"workers": "9",
var diff string
t.Log("Check if workers increased to 3")
for range 100 {
time.Sleep(10 * time.Millisecond)
gotMetrics = r.GetMetrics("akvorado_outlet_kafka_", "worker")
expected = map[string]string{
"worker_decrease_total": "0",
"worker_increase_total": "3",
"workers": "3",
}
if diff = helpers.Diff(gotMetrics, expected); diff == "" {
break
}
}
if diff := helpers.Diff(gotMetrics, expected); diff != "" {
if diff != "" {
t.Fatalf("Metrics (-got, +want):\n%s", diff)
}
time.Sleep(100 * time.Millisecond)
t.Log("Send 1 message (decrease)")
record = &kgo.Record{
Topic: expectedTopicName,
@@ -301,15 +429,21 @@ func TestWorkerScaling(t *testing.T) {
if results := producer.ProduceSync(context.Background(), record); results.FirstErr() != nil {
t.Fatalf("ProduceSync() error:\n%+v", results.FirstErr())
}
time.Sleep(100 * time.Millisecond)
t.Log("Check if workers decreased to 8")
gotMetrics = r.GetMetrics("akvorado_outlet_kafka_", "worker")
expected = map[string]string{
"worker_decrease_total": "1",
"worker_increase_total": "9",
"workers": "8",
t.Log("Check if workers decreased to 2")
for range 200 {
time.Sleep(10 * time.Millisecond)
gotMetrics = r.GetMetrics("akvorado_outlet_kafka_", "worker")
expected = map[string]string{
"worker_decrease_total": "1",
"worker_increase_total": "3",
"workers": "2",
}
if diff = helpers.Diff(gotMetrics, expected); diff == "" {
break
}
}
if diff := helpers.Diff(gotMetrics, expected); diff != "" {
if diff != "" {
t.Fatalf("Metrics (-got, +want):\n%s", diff)
}
}

View File

@@ -25,6 +25,7 @@ import (
// Component is the interface a Kafka consumer should implement.
type Component interface {
StartWorkers(WorkerBuilderFunc) error
StopWorkers()
Stop() error
}
@@ -74,6 +75,7 @@ func New(r *reporter.Reporter, configuration Configuration, dependencies Depende
kgo.FetchMaxWait(configuration.FetchMaxWaitTime),
kgo.ConsumerGroup(configuration.ConsumerGroup),
kgo.ConsumeStartOffset(kgo.NewOffset().AtEnd()),
kgo.ConsumeResetOffset(kgo.NewOffset().AtEnd()),
kgo.ConsumeTopics(fmt.Sprintf("%s-v%d", configuration.Topic, pb.Version)),
kgo.AutoCommitMarks(),
kgo.AutoCommitInterval(time.Second),
@@ -170,10 +172,19 @@ func (c *realComponent) StartWorkers(workerBuilder WorkerBuilderFunc) error {
return nil
}
// StopWorkers stops all workers. Each worker's stop function is invoked in
// turn while holding the worker mutex so no worker can be added concurrently.
func (c *realComponent) StopWorkers() {
	c.workerMu.Lock()
	defer c.workerMu.Unlock()
	for i := range c.workers {
		c.workers[i].stop()
	}
}
// Stop stops the Kafka component
func (c *realComponent) Stop() error {
defer func() {
c.stopAllWorkers()
c.StopWorkers()
c.kadmClientMu.Lock()
defer c.kadmClientMu.Unlock()
if c.kadmClient != nil {

View File

@@ -50,8 +50,13 @@ func (c *mockComponent) StartWorkers(workerBuilder WorkerBuilderFunc) error {
return nil
}
// StopWorkers stop all workers. Closing the incoming channel signals
// consumers of c.incoming that no more messages will arrive. Note that a
// second call would panic on the double close — callers must invoke it at
// most once.
func (c *mockComponent) StopWorkers() {
close(c.incoming)
}
// Stop stops the mock component.
// NOTE(review): as rendered here the body both closes c.incoming directly
// and calls c.StopWorkers() (which also closes c.incoming) — that would
// panic on the double close. This looks like a diff-rendering artifact where
// the removed line and its replacement both appear; confirm only the
// StopWorkers() delegation remains in the merged file.
func (c *mockComponent) Stop() error {
close(c.incoming)
c.StopWorkers()
return nil
}

View File

@@ -55,9 +55,10 @@ func (c *realComponent) startOneWorker() error {
return err
}
callback, shutdown := c.workerBuilder(i, c.workerRequestChan)
consumer := c.NewConsumer(i, callback)
consumer := c.newConsumer(i, callback)
// Goroutine for worker
done := make(chan bool)
ctx, cancel := context.WithCancelCause(context.Background())
ctx = c.t.Context(ctx)
c.t.Go(func() error {
@@ -76,6 +77,7 @@ func (c *realComponent) startOneWorker() error {
client.CloseAllowingRebalance()
shutdown()
close(done)
}()
for {
@@ -103,6 +105,7 @@ func (c *realComponent) startOneWorker() error {
c.workers = append(c.workers, worker{
stop: func() {
cancel(ErrStopProcessing)
<-done
},
})
c.metrics.workerIncrease.Inc()
@@ -125,15 +128,6 @@ func (c *realComponent) stopOneWorker() {
c.metrics.workerDecrease.Inc()
}
// stopAllWorkers stops all workers
func (c *realComponent) stopAllWorkers() {
c.workerMu.Lock()
defer c.workerMu.Unlock()
for _, worker := range c.workers {
worker.stop()
}
}
// onPartitionsRevoked is called when partitions are revoked. We need to commit.
func (c *realComponent) onPartitionsRevoked(ctx context.Context, client *kgo.Client, _ map[string][]int32) {
if err := client.CommitMarkedOffsets(ctx); err != nil {