128 lines
2.9 KiB
Go
128 lines
2.9 KiB
Go
package shortcircuit_test
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"slices"
|
|
"testing"
|
|
|
|
"git.codemonkeysoftware.net/b/peachy-go/shortcircuit"
|
|
"github.com/shoenig/test"
|
|
"github.com/shoenig/test/must"
|
|
)
|
|
|
|
// ErrFakeFailure is the error FailingWriter returns once it has been told to fail.
var ErrFakeFailure = errors.New("operation successfully failed")

// FailingWriter is a test double that discards everything written to it and
// can be switched into a failure mode, optionally after accepting a fixed
// number of further bytes. Its zero value succeeds on every write.
type FailingWriter struct {
	fail       bool // whether writes are in failure mode
	failAfter  int  // bytes still accepted before failing; only used when fail is true
	writeCalls int  // number of Write calls, successful or not
}

// Fail makes every subsequent Write fail immediately, reporting zero bytes written.
func (f *FailingWriter) Fail() {
	f.FailAfter(0)
}

// FailAfter puts the writer into failure mode with a budget of nBytes: writes
// keep succeeding while they fit strictly within the remaining budget, and the
// write that exhausts it reports the leftover budget and ErrFakeFailure.
// A negative nBytes is treated as 0 so Write never reports a negative count.
func (f *FailingWriter) FailAfter(nBytes int) {
	if nBytes < 0 {
		nBytes = 0
	}
	f.failAfter = nBytes
	f.fail = true
}

// Succeed takes the writer out of failure mode; subsequent writes succeed.
func (f *FailingWriter) Succeed() {
	f.fail = false
}

// Write counts the call and, outside failure mode, reports all of p as written.
// In failure mode, a write shorter than the remaining budget succeeds and
// consumes that much budget; otherwise it reports only the remaining budget as
// written and returns ErrFakeFailure. Once the budget is spent, every write
// returns 0, ErrFakeFailure until Succeed or FailAfter is called.
func (f *FailingWriter) Write(p []byte) (int, error) {
	f.writeCalls++
	if !f.fail {
		return len(p), nil
	}
	if f.failAfter > len(p) {
		f.failAfter -= len(p)
		return len(p), nil
	}
	// The remaining budget (0 <= failAfter <= len(p)) is all we "write" before failing.
	n := f.failAfter
	f.failAfter = 0
	return n, ErrFakeFailure
}

// WriteCalls returns how many times Write has been called, including failed calls.
func (f *FailingWriter) WriteCalls() int {
	return f.writeCalls
}
|
|
|
|
func TestWriter(t *testing.T) {
|
|
t.Run("writes multiple times and counts total bytes written", func(t *testing.T) {
|
|
inputs := [][]byte{
|
|
[]byte("abcdefghi"),
|
|
[]byte("jklmnopqr"),
|
|
[]byte("stuvwxyz"),
|
|
}
|
|
var buf bytes.Buffer
|
|
w := shortcircuit.NewWriter(&buf)
|
|
for _, b := range inputs {
|
|
n, err := w.Write(slices.Clone(b))
|
|
must.NoError(t, err)
|
|
must.False(t, w.Failed())
|
|
must.EqOp(t, len(b), n, must.Sprint("expected Write to return length of byte slice"))
|
|
}
|
|
allInput := slices.Concat(inputs...)
|
|
test.EqFunc(t, allInput, buf.Bytes(), bytes.Equal, test.Sprintf("expected written bytes to equal original bytes"))
|
|
n, err := w.Status()
|
|
test.EqOp(t, int64(len(allInput)), n, test.Sprint("expected total bytes written to equal total length of inputs"))
|
|
test.NoError(t, err)
|
|
})
|
|
|
|
t.Run("fails without writing if error was already returned, but tries again after clearing error", func(t *testing.T) {
|
|
w := new(FailingWriter)
|
|
sc := shortcircuit.NewWriter(w)
|
|
b := []byte("abcdefghi")
|
|
expectedTotal := int64(len(b))
|
|
|
|
// Succeed
|
|
n, err := sc.Write(b)
|
|
must.EqOp(t, len(b), n)
|
|
must.NoError(t, err)
|
|
must.EqOp(t, 1, w.WriteCalls())
|
|
must.False(t, sc.Failed())
|
|
|
|
// Fail
|
|
const limit = 3
|
|
w.FailAfter(limit)
|
|
expectedTotal += limit
|
|
|
|
n, err = sc.Write(b)
|
|
must.EqOp(t, limit, n)
|
|
must.EqOp(t, ErrFakeFailure, err)
|
|
must.True(t, sc.Failed())
|
|
must.EqOp(t, 2, w.WriteCalls())
|
|
total, err := sc.Status()
|
|
must.EqOp(t, ErrFakeFailure, err)
|
|
must.EqOp(t, expectedTotal, total)
|
|
|
|
// Fail again
|
|
n, err = sc.Write(b)
|
|
must.EqOp(t, 0, n)
|
|
must.ErrorIs(t, err, shortcircuit.Error{})
|
|
must.ErrorIs(t, err, ErrFakeFailure)
|
|
must.True(t, sc.Failed())
|
|
must.EqOp(t, 2, w.WriteCalls())
|
|
total, err = sc.Status()
|
|
must.EqOp(t, ErrFakeFailure, err)
|
|
must.EqOp(t, expectedTotal, total)
|
|
|
|
// Clear and succeed
|
|
w.Succeed()
|
|
sc.ClearError()
|
|
expectedTotal += int64(len(b))
|
|
|
|
n, err = sc.Write(b)
|
|
must.EqOp(t, len(b), n)
|
|
must.NoError(t, err)
|
|
must.False(t, sc.Failed())
|
|
must.EqOp(t, 3, w.WriteCalls())
|
|
total, err = sc.Status()
|
|
must.EqOp(t, expectedTotal, total)
|
|
must.NoError(t, err)
|
|
})
|
|
}
|