diff options
author | Aleksandr Nogikh <nogikh@google.com> | 2024-05-03 13:12:00 +0200 |
---|---|---|
committer | Dmitry Vyukov <dvyukov@google.com> | 2024-05-16 15:38:27 +0000 |
commit | 03820adaef911ce08278d95f034f134c3c0c852e (patch) | |
tree | 57f87ce0f3dedda459fb1771d3b79ff96e0853bb | |
parent | ef5d53ed7e3c7d30481a88301f680e37a5cc4775 (diff) | |
download | syzkaller-03820adaef911ce08278d95f034f134c3c0c852e.tar.gz |
pkg/fuzzer: use queue layers
Instead of relying on a fuzzer-internal priority queue, utilize
stackable layers of request-generating steps.
Move the functionality to a separate pkg/fuzzer/queue package.
The pkg/fuzzer/queue package can be reused to add extra processing
layers on top of the fuzzing and to combine machine checking and fuzzing
execution pipelines.
-rw-r--r-- | pkg/fuzzer/fuzzer.go | 243 | ||||
-rw-r--r-- | pkg/fuzzer/fuzzer_test.go | 23 | ||||
-rw-r--r-- | pkg/fuzzer/job.go | 135 | ||||
-rw-r--r-- | pkg/fuzzer/job_test.go | 9 | ||||
-rw-r--r-- | pkg/fuzzer/prio_queue_test.go | 59 | ||||
-rw-r--r-- | pkg/fuzzer/queue/prio_queue.go (renamed from pkg/fuzzer/prio_queue.go) | 36 | ||||
-rw-r--r-- | pkg/fuzzer/queue/prio_queue_test.go | 40 | ||||
-rw-r--r-- | pkg/fuzzer/queue/queue.go | 270 | ||||
-rw-r--r-- | pkg/fuzzer/queue/queue_test.go | 54 | ||||
-rw-r--r-- | syz-manager/rpc.go | 30 |
10 files changed, 575 insertions, 324 deletions
diff --git a/pkg/fuzzer/fuzzer.go b/pkg/fuzzer/fuzzer.go index 7c8b63bab..9d8957922 100644 --- a/pkg/fuzzer/fuzzer.go +++ b/pkg/fuzzer/fuzzer.go @@ -9,12 +9,11 @@ import ( "math/rand" "runtime" "sync" - "sync/atomic" "time" "github.com/google/syzkaller/pkg/corpus" + "github.com/google/syzkaller/pkg/fuzzer/queue" "github.com/google/syzkaller/pkg/ipc" - "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/pkg/stats" "github.com/google/syzkaller/prog" ) @@ -34,8 +33,7 @@ type Fuzzer struct { ctMu sync.Mutex // TODO: use RWLock. ctRegenerate chan struct{} - nextExec *priorityQueue[*Request] - nextJobID atomic.Int64 + execQueues } func NewFuzzer(ctx context.Context, cfg *Config, rnd *rand.Rand, @@ -57,9 +55,8 @@ func NewFuzzer(ctx context.Context, cfg *Config, rnd *rand.Rand, // We're okay to lose some of the messages -- if we are already // regenerating the table, we don't want to repeat it right away. ctRegenerate: make(chan struct{}), - - nextExec: makePriorityQueue[*Request](), } + f.execQueues = newExecQueues(f) f.updateChoiceTable(nil) go f.choiceTableUpdater() if cfg.Debug { @@ -68,67 +65,105 @@ func NewFuzzer(ctx context.Context, cfg *Config, rnd *rand.Rand, return f } -type Config struct { - Debug bool - Corpus *corpus.Corpus - Logf func(level int, msg string, args ...interface{}) - Coverage bool - FaultInjection bool - Comparisons bool - Collide bool - EnabledCalls map[*prog.Syscall]bool - NoMutateCalls map[int]bool - FetchRawCover bool - NewInputFilter func(call string) bool +type execQueues struct { + smashQueue *queue.PlainQueue + triageQueue *queue.PriorityQueue + candidateQueue *queue.PlainQueue + triageCandidateQueue *queue.PriorityQueue + source queue.Source } -type Request struct { - Prog *prog.Prog - NeedSignal SignalType - NeedCover bool - NeedHints bool - // If specified, the resulting signal for call SignalFilterCall - // will include subset of it even if it's not new. 
- SignalFilter signal.Signal - SignalFilterCall int - // Fields that are only relevant within pkg/fuzzer. - flags ProgTypes - stat *stats.Val - resultC chan *Result +func newExecQueues(fuzzer *Fuzzer) execQueues { + ret := execQueues{ + triageCandidateQueue: queue.Priority(), + candidateQueue: queue.PlainWithStat(fuzzer.StatCandidates), + triageQueue: queue.Priority(), + smashQueue: queue.Plain(), + } + // Sources are listed in the order, in which they will be polled. + ret.source = queue.Order( + ret.triageCandidateQueue, + ret.candidateQueue, + ret.triageQueue, + // Alternate smash jobs with exec/fuzz once in 3 times. + queue.Alternate(ret.smashQueue, 3), + queue.Callback(fuzzer.genFuzz), + ) + return ret } -type SignalType int +type execOpt any +type dontTriage struct{} +type progFlags ProgTypes -const ( - NoSignal SignalType = iota // we don't need any signal - NewSignal // we need the newly seen signal - AllSignal // we need all signal -) +func (fuzzer *Fuzzer) validateRequest(req *queue.Request) { + if req.NeedHints && (req.NeedCover || req.NeedSignal != queue.NoSignal) { + panic("Request.NeedHints is mutually exclusive with other fields") + } + if req.SignalFilter != nil && req.NeedSignal != queue.NewSignal { + panic("SignalFilter must be used with NewSignal") + } +} -type Result struct { - Info *ipc.ProgInfo - Stop bool +func (fuzzer *Fuzzer) execute(executor queue.Executor, req *queue.Request, opts ...execOpt) *queue.Result { + fuzzer.validateRequest(req) + executor.Submit(req) + res := req.Wait(fuzzer.ctx) + fuzzer.processResult(req, res, opts...) + return res } -func (fuzzer *Fuzzer) Done(req *Request, res *Result) { +func (fuzzer *Fuzzer) prepare(req *queue.Request, opts ...execOpt) { + fuzzer.validateRequest(req) + req.OnDone(func(req *queue.Request, res *queue.Result) bool { + fuzzer.processResult(req, res, opts...) 
+ return true + }) +} + +func (fuzzer *Fuzzer) enqueue(executor queue.Executor, req *queue.Request, opts ...execOpt) { + fuzzer.prepare(req, opts...) + executor.Submit(req) +} + +func (fuzzer *Fuzzer) processResult(req *queue.Request, res *queue.Result, opts ...execOpt) { + var flags ProgTypes + var noTriage bool + for _, opt := range opts { + switch v := opt.(type) { + case progFlags: + flags = ProgTypes(v) + case dontTriage: + noTriage = true + } + } // Triage individual calls. // We do it before unblocking the waiting threads because // it may result in concurrent modification of req.Prog. // If we are already triaging this exact prog, this is flaky coverage. - if req.NeedSignal != NoSignal && res.Info != nil && req.flags&progInTriage == 0 { + if req.NeedSignal != queue.NoSignal && res.Info != nil && !noTriage { for call, info := range res.Info.Calls { - fuzzer.triageProgCall(req.Prog, &info, call, req.flags) + fuzzer.triageProgCall(req.Prog, &info, call, flags) } - fuzzer.triageProgCall(req.Prog, &res.Info.Extra, -1, req.flags) - } - // Unblock threads that wait for the result. 
- if req.resultC != nil { - req.resultC <- res + fuzzer.triageProgCall(req.Prog, &res.Info.Extra, -1, flags) } if res.Info != nil { fuzzer.statExecTime.Add(int(res.Info.Elapsed.Milliseconds())) } - req.stat.Add(1) +} + +type Config struct { + Debug bool + Corpus *corpus.Corpus + Logf func(level int, msg string, args ...interface{}) + Coverage bool + FaultInjection bool + Comparisons bool + Collide bool + EnabledCalls map[*prog.Syscall]bool + NoMutateCalls map[int]bool + FetchRawCover bool + NewInputFilter func(call string) bool } func (fuzzer *Fuzzer) triageProgCall(p *prog.Prog, info *ipc.CallInfo, call int, flags ProgTypes) { @@ -141,13 +176,18 @@ func (fuzzer *Fuzzer) triageProgCall(p *prog.Prog, info *ipc.CallInfo, call int, return } fuzzer.Logf(2, "found new signal in call %d in %s", call, p) + + queue := fuzzer.triageQueue + if flags&progCandidate > 0 { + queue = fuzzer.triageCandidateQueue + } fuzzer.startJob(fuzzer.statJobsTriage, &triageJob{ - p: p.Clone(), - call: call, - info: *info, - newSignal: newMaxSignal, - flags: flags, - jobPriority: triageJobPrio(flags), + p: p.Clone(), + call: call, + info: *info, + newSignal: newMaxSignal, + flags: flags, + queue: queue.AppendQueue(), }) } @@ -164,33 +204,7 @@ func signalPrio(p *prog.Prog, info *ipc.CallInfo, call int) (prio uint8) { return } -type Candidate struct { - Prog *prog.Prog - Smashed bool - Minimized bool -} - -func (fuzzer *Fuzzer) NextInput() *Request { - req := fuzzer.nextInput() - if req.stat == fuzzer.statExecCandidate { - fuzzer.StatCandidates.Add(-1) - } - return req -} - -func (fuzzer *Fuzzer) nextInput() *Request { - nextExec := fuzzer.nextExec.tryPop() - - // The fuzzer may become too interested in potentially very long hint and smash jobs. - // Let's leave more space for new input space exploration. 
- if nextExec != nil { - if nextExec.prio.greaterThan(priority{smashPrio}) || fuzzer.nextRand()%3 != 0 { - return nextExec.value - } else { - fuzzer.nextExec.push(nextExec) - } - } - +func (fuzzer *Fuzzer) genFuzz() *queue.Request { // Either generate a new input or mutate an existing one. mutateRate := 0.95 if !fuzzer.Config.Coverage { @@ -198,23 +212,20 @@ func (fuzzer *Fuzzer) nextInput() *Request { // more frequently because fallback signal is weak. mutateRate = 0.5 } + var req *queue.Request rnd := fuzzer.rand() if rnd.Float64() < mutateRate { - req := mutateProgRequest(fuzzer, rnd) - if req != nil { - return req - } + req = mutateProgRequest(fuzzer, rnd) + } + if req == nil { + req = genProgRequest(fuzzer, rnd) } - return genProgRequest(fuzzer, rnd) + fuzzer.prepare(req) + return req } func (fuzzer *Fuzzer) startJob(stat *stats.Val, newJob job) { fuzzer.Logf(2, "started %T", newJob) - if impl, ok := newJob.(jobSaveID); ok { - // E.g. for big and slow hint jobs, we would prefer not to serialize them, - // but rather to start them all in parallel. - impl.saveID(-fuzzer.nextJobID.Add(1)) - } go func() { stat.Add(1) fuzzer.statJobs.Add(1) @@ -224,6 +235,10 @@ func (fuzzer *Fuzzer) startJob(stat *stats.Val, newJob job) { }() } +func (fuzzer *Fuzzer) Next() *queue.Request { + return fuzzer.source.Next() +} + func (fuzzer *Fuzzer) Logf(level int, msg string, args ...interface{}) { if fuzzer.Config.Logf == nil { return @@ -231,45 +246,23 @@ func (fuzzer *Fuzzer) Logf(level int, msg string, args ...interface{}) { fuzzer.Config.Logf(level, msg, args...) 
} +type Candidate struct { + Prog *prog.Prog + Smashed bool + Minimized bool +} + func (fuzzer *Fuzzer) AddCandidates(candidates []Candidate) { - fuzzer.StatCandidates.Add(len(candidates)) for _, candidate := range candidates { - fuzzer.pushExec(candidateRequest(fuzzer, candidate), priority{candidatePrio}) + req, flags := candidateRequest(fuzzer, candidate) + fuzzer.enqueue(fuzzer.candidateQueue, req, progFlags(flags)) } } func (fuzzer *Fuzzer) rand() *rand.Rand { - return rand.New(rand.NewSource(fuzzer.nextRand())) -} - -func (fuzzer *Fuzzer) nextRand() int64 { fuzzer.mu.Lock() defer fuzzer.mu.Unlock() - return fuzzer.rnd.Int63() -} - -func (fuzzer *Fuzzer) pushExec(req *Request, prio priority) { - if req.NeedHints && (req.NeedCover || req.NeedSignal != NoSignal) { - panic("Request.NeedHints is mutually exclusive with other fields") - } - if req.SignalFilter != nil && req.NeedSignal != NewSignal { - panic("SignalFilter must be used with NewSignal") - } - fuzzer.nextExec.push(&priorityQueueItem[*Request]{ - value: req, prio: prio, - }) -} - -func (fuzzer *Fuzzer) exec(job job, req *Request) *Result { - req.resultC = make(chan *Result, 1) - fuzzer.pushExec(req, job.priority()) - select { - case <-fuzzer.ctx.Done(): - return &Result{Stop: true} - case res := <-req.resultC: - close(req.resultC) - return res - } + return rand.New(rand.NewSource(fuzzer.rnd.Int63())) } func (fuzzer *Fuzzer) updateChoiceTable(programs []*prog.Prog) { @@ -327,8 +320,8 @@ func (fuzzer *Fuzzer) logCurrentStats() { var m runtime.MemStats runtime.ReadMemStats(&m) - str := fmt.Sprintf("exec queue size: %d, running jobs: %d, heap (MB): %d", - fuzzer.nextExec.Len(), fuzzer.statJobs.Val(), m.Alloc/1000/1000) + str := fmt.Sprintf("running jobs: %d, heap (MB): %d", + fuzzer.statJobs.Val(), m.Alloc/1000/1000) fuzzer.Logf(0, "%s", str) } } diff --git a/pkg/fuzzer/fuzzer_test.go b/pkg/fuzzer/fuzzer_test.go index 13eb5609b..ab38fa783 100644 --- a/pkg/fuzzer/fuzzer_test.go +++ b/pkg/fuzzer/fuzzer_test.go 
@@ -18,6 +18,7 @@ import ( "github.com/google/syzkaller/pkg/corpus" "github.com/google/syzkaller/pkg/csource" + "github.com/google/syzkaller/pkg/fuzzer/queue" "github.com/google/syzkaller/pkg/ipc" "github.com/google/syzkaller/pkg/ipc/ipcconfig" "github.com/google/syzkaller/pkg/signal" @@ -111,9 +112,9 @@ func BenchmarkFuzzer(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - req := fuzzer.NextInput() + req := fuzzer.Next() res, _, _ := emulateExec(req) - fuzzer.Done(req, res) + req.Done(res) } }) } @@ -180,7 +181,7 @@ func TestRotate(t *testing.T) { // Based on the example from Go documentation. var crc32q = crc32.MakeTable(0xD5828281) -func emulateExec(req *Request) (*Result, string, error) { +func emulateExec(req *queue.Request) (*queue.Result, string, error) { serializedLines := bytes.Split(req.Prog.Serialize(), []byte("\n")) var info ipc.ProgInfo for i, call := range req.Prog.Calls { @@ -190,12 +191,12 @@ func emulateExec(req *Request) (*Result, string, error) { if req.NeedCover { callInfo.Cover = []uint32{cover} } - if req.NeedSignal != NoSignal { + if req.NeedSignal != queue.NoSignal { callInfo.Signal = []uint32{cover} } info.Calls = append(info.Calls, callInfo) } - return &Result{Info: &info}, "", nil + return &queue.Result{Info: &info}, "", nil } type testFuzzer struct { @@ -235,13 +236,13 @@ func (f *testFuzzer) oneMore() bool { func (f *testFuzzer) registerExecutor(proc *executorProc) { f.eg.Go(func() error { for f.oneMore() { - req := f.fuzzer.NextInput() + req := f.fuzzer.Next() res, crash, err := proc.execute(req) if err != nil { return err } if crash != "" { - res = &Result{Stop: true} + res = &queue.Result{Stop: true} if !f.expectedCrashes[crash] { return fmt.Errorf("unexpected crash: %q", crash) } @@ -250,7 +251,7 @@ func (f *testFuzzer) registerExecutor(proc *executorProc) { f.crashes[crash]++ f.mu.Unlock() } - f.fuzzer.Done(req, res) + req.Done(res) } return nil }) @@ -296,10 +297,10 @@ func newProc(t *testing.T, 
target *prog.Target, executor string) *executorProc { var crashRe = regexp.MustCompile(`{{CRASH: (.*?)}}`) -func (proc *executorProc) execute(req *Request) (*Result, string, error) { +func (proc *executorProc) execute(req *queue.Request) (*queue.Result, string, error) { execOpts := proc.execOpts // TODO: it's duplicated from fuzzer.go. - if req.NeedSignal != NoSignal { + if req.NeedSignal != queue.NoSignal { execOpts.ExecFlags |= ipc.FlagCollectSignal } if req.NeedCover { @@ -313,7 +314,7 @@ func (proc *executorProc) execute(req *Request) (*Result, string, error) { } else if err != nil { return nil, "", err } - return &Result{Info: info}, "", nil + return &queue.Result{Info: info}, "", nil } func checkGoroutineLeaks() { diff --git a/pkg/fuzzer/job.go b/pkg/fuzzer/job.go index 5663b6723..8f81ef9fa 100644 --- a/pkg/fuzzer/job.go +++ b/pkg/fuzzer/job.go @@ -9,27 +9,15 @@ import ( "github.com/google/syzkaller/pkg/corpus" "github.com/google/syzkaller/pkg/cover" + "github.com/google/syzkaller/pkg/fuzzer/queue" "github.com/google/syzkaller/pkg/ipc" "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/pkg/stats" "github.com/google/syzkaller/prog" ) -const ( - smashPrio int64 = iota + 1 - genPrio - triagePrio - candidatePrio - candidateTriagePrio -) - type job interface { run(fuzzer *Fuzzer) - priority() priority -} - -type jobSaveID interface { - saveID(id int64) } type ProgTypes int @@ -38,44 +26,20 @@ const ( progCandidate ProgTypes = 1 << iota progMinimized progSmashed - progInTriage ) -type jobPriority struct { - prio priority -} - -var _ jobSaveID = new(jobPriority) - -func newJobPriority(base int64) jobPriority { - prio := append(make(priority, 0, 2), base) - return jobPriority{prio} -} - -func (jp jobPriority) priority() priority { - return jp.prio -} - -// If we prioritize execution requests only by the base priorities of their origin -// jobs, we risk letting 1000s of simultaneous jobs slowly progress in parallel. 
-// It's better to let same-prio jobs that were started earlier finish first. -// saveID() allows Fuzzer to attach this sub-prio at the moment of job creation. -func (jp *jobPriority) saveID(id int64) { - jp.prio = append(jp.prio, id) -} - -func genProgRequest(fuzzer *Fuzzer, rnd *rand.Rand) *Request { +func genProgRequest(fuzzer *Fuzzer, rnd *rand.Rand) *queue.Request { p := fuzzer.target.Generate(rnd, prog.RecommendedCalls, fuzzer.ChoiceTable()) - return &Request{ + return &queue.Request{ Prog: p, - NeedSignal: NewSignal, - stat: fuzzer.statExecGenerate, + NeedSignal: queue.NewSignal, + Stat: fuzzer.statExecGenerate, } } -func mutateProgRequest(fuzzer *Fuzzer, rnd *rand.Rand) *Request { +func mutateProgRequest(fuzzer *Fuzzer, rnd *rand.Rand) *queue.Request { p := fuzzer.Config.Corpus.ChooseProgram(rnd) if p == nil { return nil @@ -87,14 +51,14 @@ func mutateProgRequest(fuzzer *Fuzzer, rnd *rand.Rand) *Request { fuzzer.Config.NoMutateCalls, fuzzer.Config.Corpus.Programs(), ) - return &Request{ + return &queue.Request{ Prog: newP, - NeedSignal: NewSignal, - stat: fuzzer.statExecFuzz, + NeedSignal: queue.NewSignal, + Stat: fuzzer.statExecFuzz, } } -func candidateRequest(fuzzer *Fuzzer, input Candidate) *Request { +func candidateRequest(fuzzer *Fuzzer, input Candidate) (*queue.Request, ProgTypes) { flags := progCandidate if input.Minimized { flags |= progMinimized @@ -102,12 +66,11 @@ func candidateRequest(fuzzer *Fuzzer, input Candidate) *Request { if input.Smashed { flags |= progSmashed } - return &Request{ + return &queue.Request{ Prog: input.Prog, - NeedSignal: NewSignal, - stat: fuzzer.statExecCandidate, - flags: flags, - } + NeedSignal: queue.NewSignal, + Stat: fuzzer.statExecCandidate, + }, flags } // triageJob are programs for which we noticed potential new coverage during @@ -120,27 +83,28 @@ type triageJob struct { info ipc.CallInfo newSignal signal.Signal flags ProgTypes - jobPriority + fuzzer *Fuzzer + queue queue.Executor } -func triageJobPrio(flags 
ProgTypes) jobPriority { - if flags&progCandidate > 0 { - return newJobPriority(candidateTriagePrio) - } - return newJobPriority(triagePrio) +func (job *triageJob) execute(req *queue.Request, opts ...execOpt) *queue.Result { + return job.fuzzer.execute(job.queue, req, opts...) } func (job *triageJob) run(fuzzer *Fuzzer) { fuzzer.statNewInputs.Add(1) + job.fuzzer = fuzzer + callName := fmt.Sprintf("call #%v %v", job.call, job.p.CallName(job.call)) fuzzer.Logf(3, "triaging input for %v (new signal=%v)", callName, job.newSignal.Len()) + // Compute input coverage and non-flaky signal for minimization. - info, stop := job.deflake(fuzzer.exec, fuzzer.statExecTriage, fuzzer.Config.FetchRawCover) + info, stop := job.deflake(job.execute, fuzzer.statExecTriage, fuzzer.Config.FetchRawCover) if stop || info.newStableSignal.Empty() { return } if job.flags&progMinimized == 0 { - stop = job.minimize(fuzzer, info.newStableSignal) + stop = job.minimize(info.newStableSignal) if stop { return } @@ -172,8 +136,8 @@ type deflakedCover struct { rawCover []uint32 } -func (job *triageJob) deflake(exec func(job, *Request) *Result, stat *stats.Val, rawCover bool) ( - info deflakedCover, stop bool) { +func (job *triageJob) deflake(exec func(*queue.Request, ...execOpt) *queue.Result, stat *stats.Val, + rawCover bool) (info deflakedCover, stop bool) { // As demonstrated in #4639, programs reproduce with a very high, but not 100% probability. // The triage algorithm must tolerate this, so let's pick the signal that is common // to 3 out of 5 runs. @@ -194,13 +158,12 @@ func (job *triageJob) deflake(exec func(job, *Request) *Result, stat *stats.Val, // There's no chance to get coverage common to needRuns. 
break } - result := exec(job, &Request{ + result := exec(&queue.Request{ Prog: job.p, - NeedSignal: AllSignal, + NeedSignal: queue.AllSignal, NeedCover: true, - stat: stat, - flags: progInTriage, - }) + Stat: stat, + }, &dontTriage{}) if result.Stop { stop = true return @@ -226,7 +189,7 @@ func (job *triageJob) deflake(exec func(job, *Request) *Result, stat *stats.Val, return } -func (job *triageJob) minimize(fuzzer *Fuzzer, newSignal signal.Signal) (stop bool) { +func (job *triageJob) minimize(newSignal signal.Signal) (stop bool) { const minimizeAttempts = 3 job.p, job.call = prog.Minimize(job.p, job.call, false, func(p1 *prog.Prog, call1 int) bool { @@ -234,12 +197,12 @@ func (job *triageJob) minimize(fuzzer *Fuzzer, newSignal signal.Signal) (stop bo return false } for i := 0; i < minimizeAttempts; i++ { - result := fuzzer.exec(job, &Request{ + result := job.execute(&queue.Request{ Prog: p1, - NeedSignal: NewSignal, + NeedSignal: queue.NewSignal, SignalFilter: newSignal, SignalFilterCall: call1, - stat: fuzzer.statExecMinimize, + Stat: job.fuzzer.statExecMinimize, }) if result.Stop { stop = true @@ -288,10 +251,6 @@ type smashJob struct { call int } -func (job *smashJob) priority() priority { - return priority{smashPrio} -} - func (job *smashJob) run(fuzzer *Fuzzer) { fuzzer.Logf(2, "smashing the program %s (call=%d):", job.p, job.call) if fuzzer.Config.Comparisons && job.call >= 0 { @@ -309,18 +268,18 @@ func (job *smashJob) run(fuzzer *Fuzzer) { fuzzer.ChoiceTable(), fuzzer.Config.NoMutateCalls, fuzzer.Config.Corpus.Programs()) - result := fuzzer.exec(job, &Request{ + result := fuzzer.execute(fuzzer.smashQueue, &queue.Request{ Prog: p, - NeedSignal: NewSignal, - stat: fuzzer.statExecSmash, + NeedSignal: queue.NewSignal, + Stat: fuzzer.statExecSmash, }) if result.Stop { return } if fuzzer.Config.Collide { - result := fuzzer.exec(job, &Request{ + result := fuzzer.execute(fuzzer.smashQueue, &queue.Request{ Prog: randomCollide(p, rnd), - stat: 
fuzzer.statExecCollide, + Stat: fuzzer.statExecCollide, }) if result.Stop { return @@ -360,9 +319,9 @@ func (job *smashJob) faultInjection(fuzzer *Fuzzer) { job.call, nth) newProg := job.p.Clone() newProg.Calls[job.call].Props.FailNth = nth - result := fuzzer.exec(job, &Request{ + result := fuzzer.execute(fuzzer.smashQueue, &queue.Request{ Prog: newProg, - stat: fuzzer.statExecSmash, + Stat: fuzzer.statExecSmash, }) if result.Stop { return @@ -380,20 +339,16 @@ type hintsJob struct { call int } -func (job *hintsJob) priority() priority { - return priority{smashPrio} -} - func (job *hintsJob) run(fuzzer *Fuzzer) { // First execute the original program twice to get comparisons from KCOV. // The second execution lets us filter out flaky values, which seem to constitute ~30-40%. p := job.p var comps prog.CompMap for i := 0; i < 2; i++ { - result := fuzzer.exec(job, &Request{ + result := fuzzer.execute(fuzzer.smashQueue, &queue.Request{ Prog: p, NeedHints: true, - stat: fuzzer.statExecSeed, + Stat: fuzzer.statExecSeed, }) if result.Stop || result.Info == nil { return @@ -413,10 +368,10 @@ func (job *hintsJob) run(fuzzer *Fuzzer) { // Execute each of such mutants to check if it gives new coverage. 
p.MutateWithHints(job.call, comps, func(p *prog.Prog) bool { - result := fuzzer.exec(job, &Request{ + result := fuzzer.execute(fuzzer.smashQueue, &queue.Request{ Prog: p, - NeedSignal: NewSignal, - stat: fuzzer.statExecHint, + NeedSignal: queue.NewSignal, + Stat: fuzzer.statExecHint, }) return !result.Stop }) diff --git a/pkg/fuzzer/job_test.go b/pkg/fuzzer/job_test.go index cdf46186d..59f0e7097 100644 --- a/pkg/fuzzer/job_test.go +++ b/pkg/fuzzer/job_test.go @@ -6,6 +6,7 @@ package fuzzer import ( "testing" + "github.com/google/syzkaller/pkg/fuzzer/queue" "github.com/google/syzkaller/pkg/ipc" "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/prog" @@ -28,7 +29,7 @@ func TestDeflakeFail(t *testing.T) { } run := 0 - ret, stop := testJob.deflake(func(_ job, _ *Request) *Result { + ret, stop := testJob.deflake(func(_ *queue.Request, _ ...execOpt) *queue.Result { run++ // For first, we return 0 and 1. For second, 1 and 2. And so on. return fakeResult(0, []uint32{uint32(run), uint32(run + 1)}, []uint32{10, 20}) @@ -53,7 +54,7 @@ func TestDeflakeSuccess(t *testing.T) { newSignal: signal.FromRaw([]uint32{0, 1, 2}, 0), } run := 0 - ret, stop := testJob.deflake(func(_ job, _ *Request) *Result { + ret, stop := testJob.deflake(func(_ *queue.Request, _ ...execOpt) *queue.Result { run++ switch run { case 1: @@ -79,8 +80,8 @@ func TestDeflakeSuccess(t *testing.T) { assert.ElementsMatch(t, []uint32{0, 2}, ret.newStableSignal.ToRaw()) } -func fakeResult(errno int, signal, cover []uint32) *Result { - return &Result{ +func fakeResult(errno int, signal, cover []uint32) *queue.Result { + return &queue.Result{ Info: &ipc.ProgInfo{ Calls: []ipc.CallInfo{ { diff --git a/pkg/fuzzer/prio_queue_test.go b/pkg/fuzzer/prio_queue_test.go deleted file mode 100644 index 3b5b87105..000000000 --- a/pkg/fuzzer/prio_queue_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2024 syzkaller project authors. All rights reserved. 
-// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. - -package fuzzer - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" -) - -func TestPriority(t *testing.T) { - assert.True(t, priority{1, 2}.greaterThan(priority{1, 1})) - assert.True(t, priority{3, 2}.greaterThan(priority{2, 3})) - assert.True(t, priority{1, -5}.greaterThan(priority{1, -10})) - assert.True(t, priority{1}.greaterThan(priority{1, -1})) - assert.False(t, priority{1}.greaterThan(priority{1, 1})) -} - -func TestPrioQueueOrder(t *testing.T) { - pq := makePriorityQueue[int]() - assert.Nil(t, pq.tryPop()) - - pq.push(&priorityQueueItem[int]{value: 1, prio: priority{1}}) - pq.push(&priorityQueueItem[int]{value: 3, prio: priority{3}}) - pq.push(&priorityQueueItem[int]{value: 2, prio: priority{2}}) - - assert.Equal(t, 3, pq.tryPop().value) - assert.Equal(t, 2, pq.tryPop().value) - assert.Equal(t, 1, pq.tryPop().value) - assert.Nil(t, pq.tryPop()) - assert.Zero(t, pq.Len()) -} - -func TestPrioQueueRace(t *testing.T) { - var eg errgroup.Group - pq := makePriorityQueue[int]() - - // Two writers. - for writer := 0; writer < 2; writer++ { - eg.Go(func() error { - for i := 0; i < 1000; i++ { - pq.push(&priorityQueueItem[int]{value: 10, prio: priority{1}}) - } - return nil - }) - } - // Two readers. - for reader := 0; reader < 2; reader++ { - eg.Go(func() error { - for i := 0; i < 1000; i++ { - pq.tryPop() - } - return nil - }) - } - eg.Wait() -} diff --git a/pkg/fuzzer/prio_queue.go b/pkg/fuzzer/queue/prio_queue.go index c67b6216c..a71afe61d 100644 --- a/pkg/fuzzer/prio_queue.go +++ b/pkg/fuzzer/queue/prio_queue.go @@ -1,11 +1,10 @@ // Copyright 2024 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
-package fuzzer +package queue import ( "container/heap" - "sync" ) type priority []int64 @@ -30,34 +29,33 @@ func (p priority) greaterThan(other priority) bool { return false } -type priorityQueue[T any] struct { - impl priorityQueueImpl[T] - mu sync.RWMutex +func (p priority) next() priority { + if len(p) == 0 { + return p + } + newPrio := append([]int64{}, p...) + newPrio[len(newPrio)-1]-- + return newPrio } -func makePriorityQueue[T any]() *priorityQueue[T] { - return &priorityQueue[T]{} +type priorityQueueOps[T any] struct { + impl priorityQueueImpl[T] } -func (pq *priorityQueue[T]) Len() int { - pq.mu.RLock() - defer pq.mu.RUnlock() +func (pq *priorityQueueOps[T]) Len() int { return pq.impl.Len() } -func (pq *priorityQueue[T]) push(item *priorityQueueItem[T]) { - pq.mu.Lock() - defer pq.mu.Unlock() - heap.Push(&pq.impl, item) +func (pq *priorityQueueOps[T]) Push(item T, prio priority) { + heap.Push(&pq.impl, &priorityQueueItem[T]{item, prio}) } -func (pq *priorityQueue[T]) tryPop() *priorityQueueItem[T] { - pq.mu.Lock() - defer pq.mu.Unlock() +func (pq *priorityQueueOps[T]) Pop() T { if len(pq.impl) == 0 { - return nil + var def T + return def } - return heap.Pop(&pq.impl).(*priorityQueueItem[T]) + return heap.Pop(&pq.impl).(*priorityQueueItem[T]).value } // The implementation below is based on the example provided diff --git a/pkg/fuzzer/queue/prio_queue_test.go b/pkg/fuzzer/queue/prio_queue_test.go new file mode 100644 index 000000000..a82886bdd --- /dev/null +++ b/pkg/fuzzer/queue/prio_queue_test.go @@ -0,0 +1,40 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package queue + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNextPriority(t *testing.T) { + first := priority{0} + second := first.next() + third := second.next() + assert.True(t, first.greaterThan(second)) + assert.True(t, second.greaterThan(third)) +} + +func TestPriority(t *testing.T) { + assert.True(t, priority{1, 2}.greaterThan(priority{1, 1})) + assert.True(t, priority{3, 2}.greaterThan(priority{2, 3})) + assert.True(t, priority{1, -5}.greaterThan(priority{1, -10})) + assert.True(t, priority{1}.greaterThan(priority{1, -1})) + assert.False(t, priority{1}.greaterThan(priority{1, 1})) + assert.True(t, priority{1, 0}.greaterThan(priority{1})) +} + +func TestPrioQueueOrder(t *testing.T) { + pq := priorityQueueOps[int]{} + pq.Push(1, priority{1}) + pq.Push(3, priority{3}) + pq.Push(2, priority{2}) + + assert.Equal(t, 3, pq.Pop()) + assert.Equal(t, 2, pq.Pop()) + assert.Equal(t, 1, pq.Pop()) + assert.Zero(t, pq.Pop()) + assert.Zero(t, pq.Len()) +} diff --git a/pkg/fuzzer/queue/queue.go b/pkg/fuzzer/queue/queue.go new file mode 100644 index 000000000..00e83a69e --- /dev/null +++ b/pkg/fuzzer/queue/queue.go @@ -0,0 +1,270 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package queue + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/google/syzkaller/pkg/ipc" + "github.com/google/syzkaller/pkg/signal" + "github.com/google/syzkaller/pkg/stats" + "github.com/google/syzkaller/prog" +) + +type Request struct { + Prog *prog.Prog + NeedSignal SignalType + NeedCover bool + NeedHints bool + // If specified, the resulting signal for call SignalFilterCall + // will include subset of it even if it's not new. + SignalFilter signal.Signal + SignalFilterCall int + + // This stat will be incremented on request completion. 
+ Stat *stats.Val + + // The callback will be called on request completion in the LIFO order. + // If it returns false, all further processing will be stopped. + // It allows wrappers to intercept Done() requests. + callback DoneCallback + + mu sync.Mutex + result *Result + done chan struct{} +} + +type DoneCallback func(*Request, *Result) bool + +func (r *Request) OnDone(cb DoneCallback) { + oldCallback := r.callback + r.callback = func(req *Request, res *Result) bool { + r.callback = oldCallback + if !cb(req, res) { + return false + } + if oldCallback == nil { + return true + } + return oldCallback(req, res) + } +} + +func (r *Request) Done(res *Result) { + if r.callback != nil { + if !r.callback(r, res) { + return + } + } + if r.Stat != nil { + r.Stat.Add(1) + } + r.initChannel() + r.result = res + close(r.done) +} + +// Wait() blocks until we have the result. +func (r *Request) Wait(ctx context.Context) *Result { + r.initChannel() + select { + case <-ctx.Done(): + return &Result{Stop: true} + case <-r.done: + return r.result + } +} + +func (r *Request) initChannel() { + r.mu.Lock() + if r.done == nil { + r.done = make(chan struct{}) + } + r.mu.Unlock() +} + +type SignalType int + +const ( + NoSignal SignalType = iota // we don't need any signal + NewSignal // we need the newly seen signal + AllSignal // we need all signal +) + +type Result struct { + Info *ipc.ProgInfo + Stop bool +} + +// Executor describes the interface wanted by the producers of requests. +// After a Request is submitted, it's expected that the consumer will eventually +// take it and report the execution result via Done(). +type Executor interface { + Submit(req *Request) +} + +// Source describes the interface wanted by the consumers of requests. +type Source interface { + Next() *Request +} + +// PlainQueue is a straightforward thread-safe Request queue implementation. 
+type PlainQueue struct { + stat *stats.Val + mu sync.Mutex + queue []*Request + pos int +} + +func Plain() *PlainQueue { + return &PlainQueue{} +} + +func PlainWithStat(val *stats.Val) *PlainQueue { + return &PlainQueue{stat: val} +} + +func (pq *PlainQueue) Len() int { + pq.mu.Lock() + defer pq.mu.Unlock() + return len(pq.queue) - pq.pos +} + +func (pq *PlainQueue) Submit(req *Request) { + if pq.stat != nil { + pq.stat.Add(1) + } + pq.mu.Lock() + defer pq.mu.Unlock() + + // It doesn't make sense to compact the queue too often. + const minSizeToCompact = 128 + if pq.pos > len(pq.queue)/2 && len(pq.queue) >= minSizeToCompact { + copy(pq.queue, pq.queue[pq.pos:]) + for pq.pos > 0 { + newLen := len(pq.queue) - 1 + pq.queue[newLen] = nil + pq.queue = pq.queue[:newLen] + pq.pos-- + } + } + pq.queue = append(pq.queue, req) +} + +func (pq *PlainQueue) Next() *Request { + pq.mu.Lock() + defer pq.mu.Unlock() + if pq.pos < len(pq.queue) { + ret := pq.queue[pq.pos] + pq.queue[pq.pos] = nil + pq.pos++ + if pq.stat != nil { + pq.stat.Add(-1) + } + return ret + } + return nil +} + +// Order combines several different sources in a particular order. +type orderImpl struct { + sources []Source +} + +func Order(sources ...Source) Source { + return &orderImpl{sources: sources} +} + +func (o *orderImpl) Next() *Request { + for _, s := range o.sources { + req := s.Next() + if req != nil { + return req + } + } + return nil +} + +type callback struct { + cb func() *Request +} + +// Callback produces a source that calls the callback to serve every Next() request. +func Callback(cb func() *Request) Source { + return &callback{cb} +} + +func (cb *callback) Next() *Request { + return cb.cb() +} + +type alternate struct { + base Source + nth int + seq atomic.Int64 +} + +// Alternate proxies base, but returns nil every nth Next() call. 
+func Alternate(base Source, nth int) Source { + return &alternate{ + base: base, + nth: nth, + } +} + +func (a *alternate) Next() *Request { + if a.seq.Add(1)%int64(a.nth) == 0 { + return nil + } + return a.base.Next() +} + +type PriorityQueue struct { + mu *sync.Mutex + ops *priorityQueueOps[*Request] + currPrio priority +} + +func Priority() *PriorityQueue { + return &PriorityQueue{ + mu: &sync.Mutex{}, + ops: &priorityQueueOps[*Request]{}, + currPrio: priority{0}, + } +} + +// AppendQueue() can be used to form nested queues. +// That is, if +// q1 := pq.AppendQueue() +// q2 := pq.AppendQueue() +// All elements added via q2.Submit() will always have a *lower* priority +// than all elements added via q1.Submit(). +func (pq *PriorityQueue) AppendQueue() *PriorityQueue { + pq.mu.Lock() + defer pq.mu.Unlock() + pq.currPrio = pq.currPrio.next() + nextPrio := append(priority{}, pq.currPrio...) + return &PriorityQueue{ + // We use the same queue, therefore the same mutex. + mu: pq.mu, + ops: pq.ops, + currPrio: append(nextPrio, 0), + } +} + +// Each subsequent element added via Submit() will have a lower priority. +func (pq *PriorityQueue) Submit(req *Request) { + pq.mu.Lock() + defer pq.mu.Unlock() + pq.currPrio = pq.currPrio.next() + pq.ops.Push(req, pq.currPrio) +} + +func (pq *PriorityQueue) Next() *Request { + pq.mu.Lock() + defer pq.mu.Unlock() + return pq.ops.Pop() +} diff --git a/pkg/fuzzer/queue/queue_test.go b/pkg/fuzzer/queue/queue_test.go new file mode 100644 index 000000000..34a34ccbe --- /dev/null +++ b/pkg/fuzzer/queue/queue_test.go @@ -0,0 +1,54 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package queue + +import ( + "testing" + + "github.com/google/syzkaller/pkg/stats" + "github.com/stretchr/testify/assert" +) + +func TestPlainQueue(t *testing.T) { + val := stats.Create("v0", "desc0") + pq := PlainWithStat(val) + + req1, req2, req3 := &Request{}, &Request{}, &Request{} + + pq.Submit(req1) + assert.Equal(t, 1, val.Val()) + pq.Submit(req2) + assert.Equal(t, 2, val.Val()) + + assert.Equal(t, req1, pq.Next()) + assert.Equal(t, 1, val.Val()) + + assert.Equal(t, req2, pq.Next()) + assert.Equal(t, 0, val.Val()) + + pq.Submit(req3) + assert.Equal(t, 1, val.Val()) + assert.Equal(t, req3, pq.Next()) + assert.Nil(t, pq.Next()) +} + +func TestPrioQueue(t *testing.T) { + req1, req2, req3, req4 := + &Request{}, &Request{}, &Request{}, &Request{} + pq := Priority() + + pq1 := pq.AppendQueue() + pq2 := pq.AppendQueue() + pq3 := pq.AppendQueue() + + pq2.Submit(req2) + pq3.Submit(req3) + pq3.Submit(req4) + pq1.Submit(req1) + + assert.Equal(t, req1, pq.Next()) + assert.Equal(t, req2, pq.Next()) + assert.Equal(t, req3, pq.Next()) + assert.Equal(t, req4, pq.Next()) +} diff --git a/syz-manager/rpc.go b/syz-manager/rpc.go index 5a9074b15..bcfc53991 100644 --- a/syz-manager/rpc.go +++ b/syz-manager/rpc.go @@ -16,6 +16,7 @@ import ( "github.com/google/syzkaller/pkg/cover" "github.com/google/syzkaller/pkg/flatrpc" "github.com/google/syzkaller/pkg/fuzzer" + "github.com/google/syzkaller/pkg/fuzzer/queue" "github.com/google/syzkaller/pkg/ipc" "github.com/google/syzkaller/pkg/log" "github.com/google/syzkaller/pkg/mgrconfig" @@ -55,7 +56,7 @@ type RPCServer struct { // We did not finish these requests because of VM restarts. // They will be eventually given to other VMs. 
- rescuedInputs []*fuzzer.Request + rescuedInputs []*queue.Request statExecs *stats.Val statExecRetries *stats.Val @@ -87,7 +88,7 @@ type Runner struct { } type Request struct { - req *fuzzer.Request + req *queue.Request serialized []byte try int procID int @@ -369,7 +370,7 @@ func (serv *RPCServer) ExchangeInfo(a *rpctype.ExchangeInfoRequest, r *rpctype.E panic("exchange info call with nil fuzzer") } - appendRequest := func(inp *fuzzer.Request) { + appendRequest := func(inp *queue.Request) { if req, ok := serv.newRequest(runner, inp); ok { r.Requests = append(r.Requests, req) } else { @@ -377,7 +378,7 @@ func (serv *RPCServer) ExchangeInfo(a *rpctype.ExchangeInfoRequest, r *rpctype.E // but so far we don't have a better handling than counting this. // This error is observed a lot on the seeded syz_mount_image calls. serv.statExecBufferTooSmall.Add(1) - fuzzerObj.Done(inp, &fuzzer.Result{Stop: true}) + inp.Done(&queue.Result{Stop: true}) } } @@ -397,11 +398,11 @@ func (serv *RPCServer) ExchangeInfo(a *rpctype.ExchangeInfoRequest, r *rpctype.E // It should foster a more even distribution of executions // across all VMs. for len(r.Requests) < a.NeedProgs { - appendRequest(fuzzerObj.NextInput()) + appendRequest(fuzzerObj.Next()) } for _, result := range a.Results { - serv.doneRequest(runner, result, fuzzerObj) + serv.doneRequest(runner, result) } stats.Import(a.StatsDelta) @@ -477,17 +478,14 @@ func (serv *RPCServer) shutdownInstance(name string, crashed bool) []byte { close(runner.injectStop) - // The VM likely crashed, so let's tell pkg/fuzzer to abort the affected jobs. - // fuzzerObj may be null, but in that case oldRequests would be empty as well. 
serv.mu.Lock() defer serv.mu.Unlock() if !serv.checkDone.Load() { log.Fatalf("VM is exited while checking is not done") } - fuzzerObj := serv.mgr.getFuzzer() for _, req := range oldRequests { if crashed && req.try >= 0 { - fuzzerObj.Done(req.req, &fuzzer.Result{Stop: true}) + req.req.Done(&queue.Result{Stop: true}) } else { // We will resend these inputs to another VM. serv.rescuedInputs = append(serv.rescuedInputs, req.req) @@ -524,7 +522,7 @@ func (serv *RPCServer) updateCoverFilter(newCover []uint32) { serv.statCoverFiltered.Add(filtered) } -func (serv *RPCServer) doneRequest(runner *Runner, resp rpctype.ExecutionResult, fuzzerObj *fuzzer.Fuzzer) { +func (serv *RPCServer) doneRequest(runner *Runner, resp rpctype.ExecutionResult) { info := &resp.Info if info.Freshness == 0 { serv.statExecutorRestarts.Add(1) @@ -554,10 +552,10 @@ func (serv *RPCServer) doneRequest(runner *Runner, resp rpctype.ExecutionResult, } info.Extra.Cover = runner.instModules.Canonicalize(info.Extra.Cover) info.Extra.Signal = runner.instModules.Canonicalize(info.Extra.Signal) - fuzzerObj.Done(req.req, &fuzzer.Result{Info: info}) + req.req.Done(&queue.Result{Info: info}) } -func (serv *RPCServer) newRequest(runner *Runner, req *fuzzer.Request) (rpctype.ExecutionRequest, bool) { +func (serv *RPCServer) newRequest(runner *Runner, req *queue.Request) (rpctype.ExecutionRequest, bool) { progData, err := req.Prog.SerializeForExec() if err != nil { return rpctype.ExecutionRequest{}, false @@ -587,14 +585,14 @@ func (serv *RPCServer) newRequest(runner *Runner, req *fuzzer.Request) (rpctype. 
ID: id, ProgData: progData, ExecOpts: serv.createExecOpts(req), - NewSignal: req.NeedSignal == fuzzer.NewSignal, + NewSignal: req.NeedSignal == queue.NewSignal, SignalFilter: signalFilter, SignalFilterCall: req.SignalFilterCall, ResetState: serv.cfg.Experimental.ResetAccState, }, true } -func (serv *RPCServer) createExecOpts(req *fuzzer.Request) ipc.ExecOpts { +func (serv *RPCServer) createExecOpts(req *queue.Request) ipc.ExecOpts { env := ipc.FeaturesToFlags(serv.enabledFeatures, nil) if *flagDebug { env |= ipc.FlagDebug @@ -616,7 +614,7 @@ func (serv *RPCServer) createExecOpts(req *fuzzer.Request) ipc.ExecOpts { exec |= ipc.FlagEnableCoverageFilter } if serv.cfg.Cover { - if req.NeedSignal != fuzzer.NoSignal { + if req.NeedSignal != queue.NoSignal { exec |= ipc.FlagCollectSignal } if req.NeedCover { |