package shortcircuit_test

import (
	"bytes"
	"errors"
	"io"
	"slices"
	"testing"

	"git.codemonkeysoftware.net/b/peachy-go/shortcircuit"
	"github.com/shoenig/test"
	"github.com/shoenig/test/must"
)

// ErrFakeFailure is the error FailingWriter returns when it is in failing mode.
var ErrFakeFailure = errors.New("operation successfully failed")

// FailingWriter is an io.Writer test double. It discards everything written
// to it, can be switched into a failing mode, and counts how many times Write
// has been called (successful or not).
//
// The zero value always succeeds.
type FailingWriter struct {
	fail       bool // whether Write is in failing mode
	failAfter  int  // bytes still accepted before failing; only consulted when fail is set
	writeCalls int  // number of Write calls so far
}

// Compile-time check that FailingWriter satisfies io.Writer.
var _ io.Writer = (*FailingWriter)(nil)

// Fail puts the writer into failing mode with no remaining budget, so every
// subsequent Write accepts 0 bytes and returns ErrFakeFailure.
func (f *FailingWriter) Fail() {
	f.FailAfter(0)
}

// FailAfter puts the writer into failing mode: subsequent writes accept up to
// nBytes bytes in total, and the write that reaches that limit returns the
// bytes accepted up to it along with ErrFakeFailure. After that, every Write
// accepts 0 bytes and fails until Succeed is called. Negative values are
// treated as 0 so that Write never reports a negative byte count.
func (f *FailingWriter) FailAfter(nBytes int) {
	f.failAfter = max(nBytes, 0)
	f.fail = true
}

// Succeed leaves failing mode; subsequent writes accept all of their input.
func (f *FailingWriter) Succeed() {
	f.fail = false
}

// Write records the call, then reports how many bytes of p were accepted.
// In failing mode, a write that ends exactly on the limit still returns
// ErrFakeFailure (with n == len(p)).
func (f *FailingWriter) Write(p []byte) (int, error) {
	f.writeCalls++
	if !f.fail {
		return len(p), nil
	}
	if f.failAfter > len(p) {
		// Still under the limit: accept all of p and spend part of the budget.
		f.failAfter -= len(p)
		return len(p), nil
	}
	// This write hits the limit: report the partial count and leave the
	// budget exhausted so later writes fail without accepting anything.
	n := f.failAfter
	f.failAfter = 0
	return n, ErrFakeFailure
}

// WriteCalls returns the number of times Write has been called.
func (f *FailingWriter) WriteCalls() int {
	return f.writeCalls
}

// TestWriter exercises shortcircuit.Writer: byte accounting across successful
// writes, short-circuiting once the underlying writer has failed, and
// resuming after ClearError.
func TestWriter(t *testing.T) {
	t.Run("writes multiple times and counts total bytes written", func(t *testing.T) {
		inputs := [][]byte{
			[]byte("abcdefghi"),
			[]byte("jklmnopqr"),
			[]byte("stuvwxyz"),
		}
		var buf bytes.Buffer
		w := shortcircuit.NewWriter(&buf)
		for _, b := range inputs {
			// Write a copy so the slices used for the final comparison are
			// never shared with the writer under test.
			n, err := w.Write(slices.Clone(b))
			must.NoError(t, err)
			must.False(t, w.Failed())
			must.EqOp(t, len(b), n, must.Sprint("expected Write to return length of byte slice"))
		}
		allInput := slices.Concat(inputs...)
		test.EqFunc(t, allInput, buf.Bytes(), bytes.Equal, test.Sprint("expected written bytes to equal original bytes"))
		n, err := w.Status()
		test.EqOp(t, int64(len(allInput)), n, test.Sprint("expected total bytes written to equal total length of inputs"))
		test.NoError(t, err)
	})

	t.Run("fails without writing if error was already returned, but tries again after clearing error", func(t *testing.T) {
		w := new(FailingWriter)
		sc := shortcircuit.NewWriter(w)
		b := []byte("abcdefghi")
		expectedTotal := int64(len(b))

		// Succeed
		n, err := sc.Write(b)
		must.EqOp(t, len(b), n)
		must.NoError(t, err)
		must.EqOp(t, 1, w.WriteCalls())
		must.False(t, sc.Failed())

		// Fail: the underlying writer accepts only `limit` bytes, and its
		// error is expected back unwrapped.
		const limit = 3
		w.FailAfter(limit)
		expectedTotal += limit
		n, err = sc.Write(b)
		must.EqOp(t, limit, n)
		must.EqOp(t, ErrFakeFailure, err)
		must.True(t, sc.Failed())
		must.EqOp(t, 2, w.WriteCalls())
		total, err := sc.Status()
		must.EqOp(t, ErrFakeFailure, err)
		must.EqOp(t, expectedTotal, total)

		// Fail again: the underlying writer must not be called (WriteCalls
		// stays at 2), and the returned error must match both
		// shortcircuit.Error and the original failure.
		n, err = sc.Write(b)
		must.EqOp(t, 0, n)
		must.ErrorIs(t, err, shortcircuit.Error{})
		must.ErrorIs(t, err, ErrFakeFailure)
		must.True(t, sc.Failed())
		must.EqOp(t, 2, w.WriteCalls())
		total, err = sc.Status()
		must.EqOp(t, ErrFakeFailure, err)
		must.EqOp(t, expectedTotal, total)

		// Clear and succeed: writes reach the underlying writer again and
		// the running total keeps accumulating.
		w.Succeed()
		sc.ClearError()
		expectedTotal += int64(len(b))
		n, err = sc.Write(b)
		must.EqOp(t, len(b), n)
		must.NoError(t, err)
		must.False(t, sc.Failed())
		must.EqOp(t, 3, w.WriteCalls())
		total, err = sc.Status()
		must.EqOp(t, expectedTotal, total)
		must.NoError(t, err)
	})
}