internal/plugin: refactor plugin handling code

This commit is contained in:
Filippo Valsorda
2021-12-26 05:12:32 +01:00
parent 5a0da177e9
commit 87a982b72e

View File

@@ -30,18 +30,7 @@ type Recipient struct {
// identity is true when encoding is an identity string.
identity bool
// DisplayMessage is a callback that will be invoked by Wrap if the plugin
// wishes to display a message to the user. If DisplayMessage is nil or
// returns an error, failure will be reported to the plugin.
DisplayMessage func(message string) error
// RequestValue is a callback that will be invoked by Wrap if the plugin
// wishes to request a value from the user. If RequestValue is nil or
// returns an error, failure will be reported to the plugin.
RequestValue func(message string, secret bool) (string, error)
// Confirm is a callback that will be invoked by Unwrap if the plugin wishes
// to request a confirmation from the user. If Confirm is nil or returns an
// error, failure will be reported to the plugin.
Confirm func(message, yes, no string) (bool, error)
ClientUI
}
var _ age.Recipient = &Recipient{}
@@ -81,29 +70,17 @@ func (r *Recipient) Wrap(fileKey []byte) (stanzas []*age.Stanza, err error) {
defer conn.Close()
// Phase 1: client sends recipient or identity and file key
s := &format.Stanza{
Type: "add-recipient",
Args: []string{r.encoding},
}
addType := "add-recipient"
if r.identity {
s.Type = "add-identity"
addType = "add-identity"
}
if err := s.Marshal(conn); err != nil {
if err := writeStanza(conn, addType, r.encoding); err != nil {
return nil, err
}
s = &format.Stanza{
Type: "wrap-file-key",
Body: fileKey,
}
if err := s.Marshal(conn); err != nil {
if err := writeStanzaWithBody(conn, "wrap-file-key", fileKey); err != nil {
return nil, err
}
s = &format.Stanza{
Type: "done",
}
if err := s.Marshal(conn); err != nil {
if err := writeStanza(conn, "done"); err != nil {
return nil, err
}
@@ -117,95 +94,17 @@ ReadLoop:
}
switch s.Type {
case "msg":
if r.DisplayMessage == nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
break
}
if err := r.DisplayMessage(string(s.Body)); err != nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
} else {
ss := &format.Stanza{Type: "ok"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
}
case "request-secret", "request-public":
if r.RequestValue == nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
break
}
msg := string(s.Body)
if secret, err := r.RequestValue(msg, s.Type == "request-secret"); err != nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
} else {
ss := &format.Stanza{Type: "ok", Body: []byte(secret)}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
}
case "confirm":
if len(s.Args) != 1 && len(s.Args) != 2 {
return nil, fmt.Errorf("received malformed confirm stanza")
}
if r.Confirm == nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
break
}
yes, err := format.DecodeString(s.Args[0])
if err != nil {
return nil, fmt.Errorf("received malformed confirm stanza")
}
var no []byte
if len(s.Args) == 2 {
no, err = format.DecodeString(s.Args[1])
if err != nil {
return nil, fmt.Errorf("received malformed confirm stanza")
}
}
msg := string(s.Body)
if selection, err := r.Confirm(msg, string(yes), string(no)); err != nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
} else {
ss := &format.Stanza{Type: "ok"}
if selection {
ss.Args = append(ss.Args, "yes")
} else {
ss.Args = append(ss.Args, "no")
}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
}
case "recipient-stanza":
if len(s.Args) < 2 {
return nil, fmt.Errorf("received malformed recipient stanza")
return nil, fmt.Errorf("malformed recipient stanza: unexpected argument count")
}
n, err := strconv.Atoi(s.Args[0])
if err != nil {
return nil, fmt.Errorf("received malformed recipient stanza")
return nil, fmt.Errorf("malformed recipient stanza: invalid index")
}
// We only send a single file key, so the index must be 0.
if n != 0 {
return nil, fmt.Errorf("received malformed recipient stanza")
return nil, fmt.Errorf("malformed recipient stanza: unexpected index")
}
stanzas = append(stanzas, &age.Stanza{
@@ -214,23 +113,24 @@ ReadLoop:
Body: s.Body,
})
ss := &format.Stanza{Type: "ok"}
if err := ss.Marshal(conn); err != nil {
if err := writeStanza(conn, "ok"); err != nil {
return nil, err
}
case "error":
ss := &format.Stanza{Type: "ok"}
if err := ss.Marshal(conn); err != nil {
if err := writeStanza(conn, "ok"); err != nil {
return nil, err
}
return nil, fmt.Errorf("%q", s.Body)
return nil, fmt.Errorf("%s", s.Body)
case "done":
break ReadLoop
default:
ss := &format.Stanza{Type: "unsupported"}
if err := ss.Marshal(conn); err != nil {
if ok, err := r.handleUI(conn, s); err != nil {
return nil, err
} else if !ok {
if err := writeStanza(conn, "unsupported"); err != nil {
return nil, err
}
}
}
}
@@ -246,18 +146,7 @@ type Identity struct {
name string
encoding string
// DisplayMessage is a callback that will be invoked by Unwrap if the plugin
// wishes to display a message to the user. If DisplayMessage is nil or
// returns an error, failure will be reported to the plugin.
DisplayMessage func(message string) error
// RequestValue is a callback that will be invoked by Unwrap if the plugin
// wishes to request a value from the user. If RequestValue is nil or
// returns an error, failure will be reported to the plugin.
RequestValue func(message string, secret bool) (string, error)
// Confirm is a callback that will be invoked by Unwrap if the plugin wishes
// to request a confirmation from the user. If Confirm is nil or returns an
// error, failure will be reported to the plugin.
Confirm func(message, yes, no string) (bool, error)
ClientUI
}
var _ age.Identity = &Identity{}
@@ -292,10 +181,7 @@ func (i *Identity) Recipient() *Recipient {
name: i.name,
encoding: i.encoding,
identity: true,
DisplayMessage: i.DisplayMessage,
RequestValue: i.RequestValue,
Confirm: i.Confirm,
ClientUI: i.ClientUI,
}
}
@@ -313,16 +199,11 @@ func (i *Identity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) {
defer conn.Close()
// Phase 1: client sends the plugin the identity string and the stanzas
s := &format.Stanza{
Type: "add-identity",
Args: []string{i.encoding},
}
if err := s.Marshal(conn); err != nil {
if err := writeStanza(conn, "add-identity", i.encoding); err != nil {
return nil, err
}
for _, rs := range stanzas {
s = &format.Stanza{
s := &format.Stanza{
Type: "recipient-stanza",
Args: append([]string{"0", rs.Type}, rs.Args...),
Body: rs.Body,
@@ -331,11 +212,7 @@ func (i *Identity) Unwrap(stanzas []*age.Stanza) (fileKey []byte, err error) {
return nil, err
}
}
s = &format.Stanza{
Type: "done",
}
if err := s.Marshal(conn); err != nil {
if err := writeStanza(conn, "done"); err != nil {
return nil, err
}
@@ -349,95 +226,17 @@ ReadLoop:
}
switch s.Type {
case "msg":
if i.DisplayMessage == nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
break
}
if err := i.DisplayMessage(string(s.Body)); err != nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
} else {
ss := &format.Stanza{Type: "ok"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
}
case "request-secret", "request-public":
if i.RequestValue == nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
break
}
msg := string(s.Body)
if secret, err := i.RequestValue(msg, s.Type == "request-secret"); err != nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
} else {
ss := &format.Stanza{Type: "ok", Body: []byte(secret)}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
}
case "confirm":
if len(s.Args) != 1 && len(s.Args) != 2 {
return nil, fmt.Errorf("received malformed confirm stanza")
}
if i.Confirm == nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
break
}
yes, err := format.DecodeString(s.Args[0])
if err != nil {
return nil, fmt.Errorf("received malformed confirm stanza")
}
var no []byte
if len(s.Args) == 2 {
no, err = format.DecodeString(s.Args[1])
if err != nil {
return nil, fmt.Errorf("received malformed confirm stanza")
}
}
msg := string(s.Body)
if selection, err := i.Confirm(msg, string(yes), string(no)); err != nil {
ss := &format.Stanza{Type: "fail"}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
} else {
ss := &format.Stanza{Type: "ok"}
if selection {
ss.Args = append(ss.Args, "yes")
} else {
ss.Args = append(ss.Args, "no")
}
if err := ss.Marshal(conn); err != nil {
return nil, err
}
}
case "file-key":
if len(s.Args) != 1 {
return nil, fmt.Errorf("received malformed file-key stanza")
return nil, fmt.Errorf("malformed file-key stanza: unexpected arguments count")
}
n, err := strconv.Atoi(s.Args[0])
if err != nil {
return nil, fmt.Errorf("received malformed file-key stanza")
return nil, fmt.Errorf("malformed file-key stanza: invalid index")
}
// We only send a single file key, so the index must be 0.
if n != 0 {
return nil, fmt.Errorf("received malformed file-key stanza")
return nil, fmt.Errorf("malformed file-key stanza: unexpected index")
}
if fileKey != nil {
return nil, fmt.Errorf("received duplicated file-key stanza")
@@ -445,23 +244,24 @@ ReadLoop:
fileKey = s.Body
ss := &format.Stanza{Type: "ok"}
if err := ss.Marshal(conn); err != nil {
if err := writeStanza(conn, "ok"); err != nil {
return nil, err
}
case "error":
ss := &format.Stanza{Type: "ok"}
if err := ss.Marshal(conn); err != nil {
if err := writeStanza(conn, "ok"); err != nil {
return nil, err
}
return nil, fmt.Errorf("%q", s.Body)
return nil, fmt.Errorf("%s", s.Body)
case "done":
break ReadLoop
default:
ss := &format.Stanza{Type: "unsupported"}
if err := ss.Marshal(conn); err != nil {
if ok, err := i.handleUI(conn, s); err != nil {
return nil, err
} else if !ok {
if err := writeStanza(conn, "unsupported"); err != nil {
return nil, err
}
}
}
}
@@ -472,6 +272,75 @@ ReadLoop:
return fileKey, nil
}
// ClientUI holds callbacks that will be invoked by (Un)Wrap if the plugin
// wishes to interact with the user. If any of them is nil or returns an error,
// failure will be reported to the plugin.
type ClientUI struct {
// DisplayMessage displays the message, which is expected to have lowercase
// initials and no final period.
DisplayMessage func(message string) error
// RequestValue requests a secret or public input, with the provided prompt.
RequestValue func(prompt string, secret bool) (string, error)
// Confirm requests a confirmation with the provided prompt. The yes and no
// value are the choices provided to the user. no may be empty. The return
// value indicates whether the user selected the yes or no option.
Confirm func(prompt, yes, no string) (choseYes bool, err error)
}
func (c *ClientUI) handleUI(conn *clientConnection, s *format.Stanza) (ok bool, err error) {
// TODO: surface non-fatal but probably useful errors.
switch s.Type {
case "msg":
if c.DisplayMessage == nil {
return true, writeStanza(conn, "fail")
}
if err := c.DisplayMessage(string(s.Body)); err != nil {
return true, writeStanza(conn, "fail")
}
return true, writeStanza(conn, "ok")
case "request-secret", "request-public":
if c.RequestValue == nil {
return true, writeStanza(conn, "fail")
}
secret, err := c.RequestValue(string(s.Body), s.Type == "request-secret")
if err != nil {
return true, writeStanza(conn, "fail")
}
return true, writeStanzaWithBody(conn, "ok", []byte(secret))
case "confirm":
if len(s.Args) != 1 && len(s.Args) != 2 {
return true, fmt.Errorf("malformed confirm stanza: unexpected number of arguments")
}
if c.Confirm == nil {
return true, writeStanza(conn, "fail")
}
yes, err := format.DecodeString(s.Args[0])
if err != nil {
return true, fmt.Errorf("malformed confirm stanza: invalid YES option encoding")
}
var no []byte
if len(s.Args) == 2 {
no, err = format.DecodeString(s.Args[1])
if err != nil {
return true, fmt.Errorf("malformed confirm stanza: invalid NO option encoding")
}
}
choseYes, err := c.Confirm(string(s.Body), string(yes), string(no))
if err != nil {
return true, writeStanza(conn, "fail")
}
result := "yes"
if !choseYes {
result = "no"
}
return true, writeStanza(conn, "ok", result)
default:
return false, nil
}
}
type clientConnection struct {
cmd *exec.Cmd
io.Reader // stdout
@@ -527,3 +396,13 @@ func (cc *clientConnection) Close() error {
cc.cmd.Process.Signal(os.Interrupt)
return cc.cmd.Wait()
}
func writeStanza(conn io.Writer, t string, args ...string) error {
s := &format.Stanza{Type: t, Args: args}
return s.Marshal(conn)
}
func writeStanzaWithBody(conn io.Writer, t string, body []byte) error {
s := &format.Stanza{Type: t, Body: body}
return s.Marshal(conn)
}