Initial commit
diff --git a/httprr/rr_test.go b/httprr/rr_test.go
new file mode 100644
index 0000000..b20bc7d
--- /dev/null
+++ b/httprr/rr_test.go
@@ -0,0 +1,336 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httprr
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strings"
+	"testing"
+	"testing/iotest"
+)
+
+// handler is the HTTP handler served during the recording passes of
+// TestRecordReplay. It answers any path ending in "/redirect" with a 304,
+// rejects a GET whose Secret header is not "key" with a 666, and rejects a
+// POST whose body does not contain "my Secret" with a 667. Any other request
+// writes nothing, which leaves the server's default 200 response in place.
+func handler(w http.ResponseWriter, r *http.Request) {
+	switch {
+	case strings.HasSuffix(r.URL.Path, "/redirect"):
+		// The path check wins over the method checks below.
+		http.Error(w, "redirect me!", 304)
+	case r.Method == "GET":
+		if r.Header.Get("Secret") != "key" {
+			http.Error(w, "missing secret", 666)
+		}
+	case r.Method == "POST":
+		payload, err := io.ReadAll(r.Body)
+		if err != nil {
+			// A body that cannot be read is a broken test setup, not a case to report.
+			panic(err)
+		}
+		if !strings.Contains(string(payload), "my Secret") {
+			http.Error(w, "missing body secret", 667)
+		}
+	}
+}
+
+func always555(w http.ResponseWriter, r *http.Request) {
+	http.Error(w, "should not be making HTTP requests", 555)
+}
+
+// dropPort is a request scrubber that strips the port from r.URL.Host and
+// r.Host. It does nothing when the URL carries no port and always returns
+// nil. The tests register it with ScrubReq; presumably it keeps the
+// ephemeral httptest port out of the recorded request.
+func dropPort(r *http.Request) error {
+	if r.URL.Port() == "" {
+		return nil
+	}
+	// A non-empty URL.Port guarantees URL.Host contains the ':' that precedes it.
+	r.URL.Host = r.URL.Host[:strings.LastIndex(r.URL.Host, ":")]
+	// r.Host is set independently of the URL and may hold no port at all
+	// (for example when it is empty); slicing at -1 would panic.
+	if i := strings.LastIndex(r.Host, ":"); i >= 0 {
+		r.Host = r.Host[:i]
+	}
+	return nil
+}
+
+func dropSecretHeader(r *http.Request) error {
+	r.Header.Del("Secret")
+	return nil
+}
+
+// hideSecretBody is a request scrubber that replaces the request body's
+// Data with the bytes "redacted". A request without a body is left alone.
+// It always returns nil.
+func hideSecretBody(r *http.Request) error {
+	if r.Body == nil {
+		return nil
+	}
+	// Assumes the body handed to scrubbers is always a *Body; the
+	// unchecked assertion panics otherwise.
+	r.Body.(*Body).Data = []byte("redacted")
+	return nil
+}
+
+// doNothing is a response scrubber that leaves the buffer untouched and
+// reports no error.
+func doNothing(b *bytes.Buffer) error { return nil }
+
+// doRefresh is a response scrubber that empties the buffer and then writes
+// the same unread contents back into it, so the visible data is unchanged.
+// It always returns nil.
+func doRefresh(b *bytes.Buffer) error {
+	// Copy first: Reset reuses the buffer's storage.
+	contents := bytes.Clone(b.Bytes())
+	b.Reset()
+	_, _ = b.Write(contents)
+	return nil
+}
+
+// TestRecordReplay sends the same requests through the recorder four times
+// against one trace file, then checks that the file on disk never contains
+// the string "Secret".
+func TestRecordReplay(t *testing.T) {
+	dir := t.TempDir()
+	file := dir + "/rr"
+
+	// The passes overwrite the package-level record flag value; put back
+	// whatever it held so the rest of the test binary sees the original.
+	saved := *record
+	defer func() { *record = saved }()
+
+	// 4 passes:
+	//	0: create
+	//	1: open
+	//	2: Open with -httprecord="r+"
+	//	3: Open with -httprecord=""
+	for pass := range 4 {
+		recordReplayPass(t, file, pass)
+	}
+
+	data, err := os.ReadFile(file)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if strings.Contains(string(data), "Secret") {
+		t.Fatalf("rr file contains Secret:\n%s", data)
+	}
+}
+
+// recordReplayPass runs a single pass of TestRecordReplay against file.
+// Passes 0 and 2 serve handler; passes 1 and 3 serve always555. Keeping
+// each pass in its own function lets the deferred server shutdown run at
+// the end of that pass and not when the whole test returns.
+func recordReplayPass(t *testing.T, file string, pass int) {
+	start := open
+	h := always555
+	*record = ""
+	switch pass {
+	case 0:
+		start = create
+		h = handler
+	case 2:
+		start = Open
+		*record = "r+"
+		h = handler
+	case 3:
+		start = Open
+	}
+	rr, err := start(file, http.DefaultTransport)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if rr.Recording() {
+		t.Log("RECORDING")
+	} else {
+		t.Log("REPLAYING")
+	}
+	rr.ScrubReq(dropPort, dropSecretHeader)
+	rr.ScrubReq(hideSecretBody)
+	rr.ScrubResp(doNothing, doRefresh)
+
+	mustNewRequest := func(method, url string, body io.Reader) *http.Request {
+		req, err := http.NewRequest(method, url, body)
+		if err != nil {
+			t.Helper()
+			t.Fatal(err)
+		}
+		return req
+	}
+
+	// mustDo sends req through the recorder's client and fails the test
+	// unless the response arrives with the wanted status.
+	mustDo := func(req *http.Request, status int) {
+		resp, err := rr.Client().Do(req)
+		if err != nil {
+			t.Helper()
+			t.Fatal(err)
+		}
+		// The body is only used in the failure message, so a read error is ignored.
+		body, _ := io.ReadAll(resp.Body)
+		resp.Body.Close()
+		if resp.StatusCode != status {
+			t.Helper()
+			t.Fatalf("%v: %s\n%s", req.URL, resp.Status, body)
+		}
+	}
+
+	srv := httptest.NewServer(http.HandlerFunc(h))
+	defer srv.Close()
+
+	req := mustNewRequest("GET", srv.URL+"/myrequest", nil)
+	req.Header.Set("Secret", "key")
+	mustDo(req, 200)
+
+	req = mustNewRequest("POST", srv.URL+"/myrequest", strings.NewReader("my Secret"))
+	mustDo(req, 200)
+
+	req = mustNewRequest("GET", srv.URL+"/redirect", nil)
+	mustDo(req, 304)
+
+	// When not recording, a request that was never recorded must fail.
+	if !rr.Recording() {
+		req = mustNewRequest("GET", srv.URL+"/uncached", nil)
+		resp, err := rr.Client().Do(req)
+		if err == nil {
+			body, _ := io.ReadAll(resp.Body)
+			resp.Body.Close()
+			t.Fatalf("%v: %s\n%s", req.URL, resp.Status, body)
+		}
+	}
+
+	if err := rr.Close(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+var badResponseTrace = []byte("httprr trace v1\n" +
+	"92 75\n" +
+	"GET http://127.0.0.1/myrequest HTTP/1.1\r\n" +
+	"Host: 127.0.0.1\r\n" +
+	"User-Agent: Go-http-client/1.1\r\n" +
+	"\r\n" +
+	"HZZP/1.1 200 OK\r\n" +
+	"Date: Wed, 12 Jun 2024 13:55:02 GMT\r\n" +
+	"Content-Length: 0\r\n" +
+	"\r\n")
+
+// TestErrors checks that failures are reported as errors containing the
+// expected text: a bad -httprecord value, unreadable or malformed trace
+// files, failing request bodies, scrubbers and transports, and write errors
+// on the trace file.
+func TestErrors(t *testing.T) {
+	dir := t.TempDir()
+
+	// The first check overwrites the package-level record flag value; put
+	// back whatever it held once the test finishes.
+	saved := *record
+	defer func() { *record = saved }()
+
+	// makeTmpFile creates an empty file inside dir and returns its name.
+	makeTmpFile := func() string {
+		f, err := os.CreateTemp(dir, "TestErrors")
+		if err != nil {
+			t.Fatalf("failed to create tmp file for test: %v", err)
+		}
+		name := f.Name()
+		f.Close()
+		return name
+	}
+
+	// -httprecord regexp parsing
+	*record = "+"
+	if _, err := Open(makeTmpFile(), nil); err == nil || !strings.Contains(err.Error(), "invalid -httprecord flag") {
+		t.Errorf("did not diagnose bad -httprecord: err = %v", err)
+	}
+	*record = ""
+
+	// invalid httprr trace
+	if _, err := Open(makeTmpFile(), nil); err == nil || !strings.Contains(err.Error(), "not an httprr trace") {
+		t.Errorf("did not diagnose invalid httprr trace: err = %v", err)
+	}
+
+	// corrupt httprr trace
+	corruptTraceFile := makeTmpFile()
+	if err := os.WriteFile(corruptTraceFile, []byte("httprr trace v1\ngarbage\n"), 0o666); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := Open(corruptTraceFile, nil); err == nil || !strings.Contains(err.Error(), "corrupt httprr trace") {
+		t.Errorf("did not diagnose corrupt httprr trace: err = %v", err)
+	}
+
+	// os.Create error creating trace
+	if _, err := create("invalid\x00file", nil); err == nil {
+		t.Errorf("did not report failure from os.Create: err = %v", err)
+	}
+
+	// os.ReadFile error reading trace
+	if _, err := open("nonexistent", nil); err == nil {
+		t.Errorf("did not report failure from os.ReadFile: err = %v", err)
+	}
+
+	// error reading body
+	rr, err := create(makeTmpFile(), nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := rr.Client().Post("http://127.0.0.1/nonexist", "x/error", iotest.ErrReader(errors.New("MY ERROR"))); err == nil || !strings.Contains(err.Error(), "MY ERROR") {
+		t.Errorf("did not report failure from io.ReadAll(body): err = %v", err)
+	}
+
+	// error during request scrub
+	rr.ScrubReq(func(*http.Request) error { return errors.New("SCRUB ERROR") })
+	if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "SCRUB ERROR") {
+		t.Errorf("did not report failure from scrub: err = %v", err)
+	}
+	rr.Close()
+
+	// error during response scrub
+	// NOTE(review): rr was closed just above and still carries the failing
+	// request scrubber, which returns the same "SCRUB ERROR" text. This Get
+	// may therefore fail before any response scrubber runs; confirm that
+	// the response path is really what is exercised here.
+	rr.ScrubResp(func(*bytes.Buffer) error { return errors.New("SCRUB ERROR") })
+	if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "SCRUB ERROR") {
+		t.Errorf("did not report failure from scrub: err = %v", err)
+	}
+	rr.Close()
+
+	// error during rkey.WriteProxy
+	rr, err = create(makeTmpFile(), nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	rr.ScrubReq(func(req *http.Request) error {
+		req.URL = nil
+		req.Host = ""
+		return nil
+	})
+	rr.ScrubResp(func(b *bytes.Buffer) error {
+		b.Reset()
+		return nil
+	})
+	if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "no Host or URL set") {
+		t.Errorf("did not report failure from rkey.WriteProxy: err = %v", err)
+	}
+	rr.Close()
+
+	// error during resp.Write
+	rr, err = create(makeTmpFile(), badRespTransport{})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := rr.Client().Get("http://127.0.0.1/nonexist"); err == nil || !strings.Contains(err.Error(), "TRANSPORT ERROR") {
+		t.Errorf("did not report failure from resp.Write: err = %v", err)
+	}
+	rr.Close()
+
+	// error during Write logging request
+	srv := httptest.NewServer(http.HandlerFunc(always555))
+	defer srv.Close()
+	rr, err = create(makeTmpFile(), http.DefaultTransport)
+	if err != nil {
+		t.Fatal(err)
+	}
+	rr.ScrubReq(dropPort)
+	rr.record.Close() // cause write error
+	if _, err := rr.Client().Get(srv.URL + "/redirect"); err == nil || !strings.Contains(err.Error(), "file already closed") {
+		t.Errorf("did not report failure from record write: err = %v", err)
+	}
+	// Plant a distinct error so the next two checks can tell a remembered
+	// write failure apart from a fresh one.
+	rr.writeErr = errors.New("BROKEN ERROR")
+	if _, err := rr.Client().Get(srv.URL + "/redirect"); err == nil || !strings.Contains(err.Error(), "BROKEN ERROR") {
+		t.Errorf("did not report previous write failure: err = %v", err)
+	}
+	if err := rr.Close(); err == nil || !strings.Contains(err.Error(), "BROKEN ERROR") {
+		t.Errorf("did not report write failure during close: err = %v", err)
+	}
+
+	// error during RoundTrip
+	rr, err = create(makeTmpFile(), errTransport{errors.New("TRANSPORT ERROR")})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := rr.Client().Get(srv.URL); err == nil || !strings.Contains(err.Error(), "TRANSPORT ERROR") {
+		t.Errorf("did not report failure from transport: err = %v", err)
+	}
+	rr.Close()
+
+	// error during http.ReadResponse: trace is structurally okay but has malformed response inside
+	tmpFile := makeTmpFile()
+	if err := os.WriteFile(tmpFile, badResponseTrace, 0o666); err != nil {
+		t.Fatal(err)
+	}
+	rr, err = Open(tmpFile, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := rr.Client().Get("http://127.0.0.1/myrequest"); err == nil || !strings.Contains(err.Error(), "corrupt httprr trace:") {
+		t.Errorf("did not diagnose corrupt response in httprr trace: err = %v", err)
+	}
+	rr.Close()
+}
+
+// errTransport is a RoundTrip implementation that never produces a
+// response: every call returns a nil response together with err.
+type errTransport struct {
+	err error
+}
+
+// RoundTrip ignores req and returns the configured error.
+func (t errTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	return nil, t.err
+}
+
+// badRespTransport is a RoundTrip implementation whose call succeeds but
+// whose response body fails with "TRANSPORT ERROR" as soon as it is read.
+type badRespTransport struct{}
+
+// RoundTrip ignores req and returns an otherwise zero-valued response
+// wrapping the failing body, with a nil error.
+func (badRespTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	failing := iotest.ErrReader(errors.New("TRANSPORT ERROR"))
+	return &http.Response{Body: io.NopCloser(failing)}, nil
+}