// Copyright 2013 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package sshstorage

import (
	"bufio"
	"bytes"
	"encoding/base64"
	"fmt"
	"io"
	"io/ioutil"
	"path"
	"sort"
	"strconv"
	"strings"

	"github.com/juju/errors"
	"github.com/juju/loggo"
	"github.com/juju/utils"

	"github.com/juju/juju/utils/ssh"
)

// logger is this package's logger, registered under the name
// "juju.environs.sshstorage".
var logger = loggo.GetLogger("juju.environs.sshstorage")

// base64LineLength is the default line length for wrapping
// output generated by the base64 command line utility.
// copyAsBase64 passes it to newLineWrapWriter so that the input we send
// is wrapped the same way.
const base64LineLength = 76

// SSHStorage implements storage.Storage.
//
// The storage is created under sudo, and ownership given over to the
// login uid/gid. This is done so that we don't require sudo, and by
// consequence, don't require a pty, so we can interact with a script
// via stdin.
//
// Every operation is written as a shell command line to a single remote
// bash session, and is serialised on the remote side by running it under
// flock(1) on the storage directory. The struct itself has no locking of
// its own; concurrent use from several goroutines is presumably not
// supported — TODO confirm.
type SSHStorage struct {
	host       string // [user@]hostname the session is connected to
	remotepath string // remote storage root directory
	tmpdir     string // remote directory for Put's temporary files

	cmd     *ssh.Cmd       // the running remote "bash" session
	stdin   io.WriteCloser // commands (and heredoc input) are written here
	stdout  io.ReadCloser  // the session's combined stdout and stderr
	scanner *bufio.Scanner // line scanner over stdout
}

// sshCommand builds the ssh command that runs command on host. It is a
// package variable rather than a function, presumably so that it can be
// replaced (e.g. in tests) — TODO confirm.
var sshCommand = func(host string, command ...string) *ssh.Cmd {
	return ssh.Command(host, command, nil)
}

// flockmode is the lock-mode flag passed to flock(1) when a storage
// command is run against the storage directory.
type flockmode string

const (
	flockShared    flockmode = "-s" // shared lock, used by the read operations
	flockExclusive flockmode = "-x" // exclusive lock, used by the write operations
)

// NewSSHStorageParams holds the parameters for NewSSHStorage.
type NewSSHStorageParams struct {
	// Host is the host to connect to, in the format [user@]hostname.
	Host string

	// StorageDir is the root of the remote storage directory. It must be
	// non-empty. NewSSHStorage creates it (via sudo) if it does not exist,
	// with the login user as owner.
	StorageDir string

	// TmpDir is the remote temporary directory for storage.
	// A temporary directory must be specified, and should be located on the
	// same filesystem as the storage directory to ensure atomic writes.
	// The temporary directory will be created when NewSSHStorage is invoked
	// if it doesn't already exist; it will never be removed. NewSSHStorage
	// will attempt to reassign ownership to the login user, and will return
	// an error if it cannot do so.
	TmpDir string
}

// NewSSHStorage creates a new SSHStorage, connected to the
// specified host, managing state under the specified remote path.
//
// It first runs "sudo -n /bin/bash" on the host to create the storage and
// temporary directories, owned by the invoking user ($SUDO_UID/$SUDO_GID).
// It then starts a long-lived, unprivileged bash session through which all
// storage operations are executed. Finally it verifies write access by
// touching the storage directory under an exclusive lock.
//
// Any failure is returned as an error. If the session had been started,
// it is shut down first.
func NewSSHStorage(params NewSSHStorageParams) (*SSHStorage, error) {
	if params.StorageDir == "" {
		return nil, errors.New("storagedir must be specified and non-empty")
	}
	if params.TmpDir == "" {
		return nil, errors.New("tmpdir must be specified and non-empty")
	}

	script := fmt.Sprintf(
		"install -d -g $SUDO_GID -o $SUDO_UID %s %s",
		utils.ShQuote(params.StorageDir),
		utils.ShQuote(params.TmpDir),
	)

	cmd := sshCommand(params.Host, "sudo", "-n", "/bin/bash")
	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	cmd.Stdin = strings.NewReader(script)
	if err := cmd.Run(); err != nil {
		err = fmt.Errorf("failed to create storage dir: %v (%v)", err, strings.TrimSpace(stderr.String()))
		return nil, err
	}

	// We could use sftp, but then we'd be at the mercy of
	// sftp's output messages for checking errors. Instead,
	// we execute an interactive bash shell.
	cmd = sshCommand(params.Host, "bash")
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, err
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		stdin.Close()
		return nil, err
	}
	// Combine stdout and stderr, so we can easily
	// get at catastrophic failure messages.
	cmd.Stderr = cmd.Stdout

	// A session that failed to start has nothing to Wait for; release the
	// pipes and report the failure instead of handing back a dead storage.
	if err := cmd.Start(); err != nil {
		stdin.Close()
		stdout.Close()
		return nil, fmt.Errorf("failed to start remote shell: %v", err)
	}
	stor := &SSHStorage{
		host:       params.Host,
		remotepath: params.StorageDir,
		tmpdir:     params.TmpDir,
		cmd:        cmd,
		stdin:      stdin,
		stdout:     stdout,
		scanner:    bufio.NewScanner(stdout),
	}

	// Verify we have write permissions.
	_, err = stor.runf(flockExclusive, "touch %s", utils.ShQuote(params.StorageDir))
	if err != nil {
		stdin.Close()
		stdout.Close()
		cmd.Wait()
		return nil, err
	}
	return stor, nil
}

// Close cleanly terminates the underlying SSH connection.
//
// Closing our end of stdin lets the remote bash see EOF and exit. Errors
// from closing either pipe are deliberately ignored; the result of waiting
// for the ssh process is what gets reported.
func (s *SSHStorage) Close() error {
	_ = s.stdin.Close()
	_ = s.stdout.Close()
	return s.cmd.Wait()
}

// runf formats command with args (as fmt.Sprintf does) and runs the result
// under the given flock mode, without any stdin input.
func (s *SSHStorage) runf(flockmode flockmode, command string, args ...interface{}) (string, error) {
	formatted := fmt.Sprintf(command, args...)
	return s.run(flockmode, formatted, nil, 0)
}

// terminate closes the session's stdin and appends whatever output remains
// on stdout to err, so a failure carries any diagnostics the remote side
// printed before exiting. It reads until the scanner stops (normally at EOF
// once the remote shell exits). If there is no remaining output, err is
// returned unchanged.
func (s *SSHStorage) terminate(err error) error {
	s.stdin.Close()
	// Collect the lines and join them once, rather than concatenating
	// strings in the loop.
	var lines []string
	for s.scanner.Scan() {
		lines = append(lines, s.scanner.Text())
	}
	if len(lines) > 0 {
		err = fmt.Errorf("%v (output: %q)", err, strings.Join(lines, "\n"))
	}
	return err
}

// run executes command in the remote bash session and returns its output.
//
// The command is wrapped so that it runs via flock(1) on the storage
// directory with the given lock mode, with stderr merged into stdout. If
// input is non-nil, it is base64-encoded and supplied to the command's stdin
// through a heredoc terminated by "@EOF". inputlen is currently unused.
//
// The exit status is echoed on a line prefixed with "JUJU-RC: ". On status
// 0, the preceding output lines are returned joined with "\n". On any other
// status, an SSHStorageError carrying that output is returned.
//
// Any failure to write to the session tears it down via terminate, because
// a partially written command leaves the remote shell in an unknown state.
func (s *SSHStorage) run(flockmode flockmode, command string, input io.Reader, inputlen int64) (string, error) {
	const rcPrefix = "JUJU-RC: "
	command = fmt.Sprintf(
		"SHELL=/bin/bash flock %s %s -c %s",
		flockmode,
		utils.ShQuote(s.remotepath),
		utils.ShQuote(command),
	)
	if input != nil {
		command = fmt.Sprintf("base64 -d << '@EOF' | (%s)", command)
	}
	command = fmt.Sprintf("(%s) 2>&1; echo %s$?", command, rcPrefix)

	stdin := bufio.NewWriter(s.stdin)
	if _, err := stdin.WriteString(command + "\n"); err != nil {
		return "", s.terminate(fmt.Errorf("failed to write command: %v", err))
	}
	if input != nil {
		if err := copyAsBase64(stdin, input); err != nil {
			return "", s.terminate(fmt.Errorf("failed to write input: %v", err))
		}
	}
	if err := stdin.Flush(); err != nil {
		return "", s.terminate(fmt.Errorf("failed to write input: %v", err))
	}

	var output []string
	for s.scanner.Scan() {
		line := s.scanner.Text()
		if !strings.HasPrefix(line, rcPrefix) {
			output = append(output, line)
			continue
		}
		rcText := strings.TrimPrefix(line, rcPrefix)
		rc, err := strconv.Atoi(rcText)
		if err != nil {
			return "", fmt.Errorf("failed to parse exit code %q: %v", rcText, err)
		}
		outputJoined := strings.Join(output, "\n")
		if rc != 0 {
			return "", SSHStorageError{outputJoined, rc}
		}
		return outputJoined, nil
	}

	// The session ended (or the scanner failed) before an exit status arrived.
	err := fmt.Errorf("failed to locate %q", rcPrefix)
	if len(output) > 0 {
		err = fmt.Errorf("%v (output: %q)", err, strings.Join(output, "\n"))
	}
	if scannerErr := s.scanner.Err(); scannerErr != nil {
		err = fmt.Errorf("%v (scanner error: %v)", err, scannerErr)
	}
	return "", err
}

// copyAsBase64 writes the contents of r to w, base64-encoded with the
// standard alphabet and line-wrapped via newLineWrapWriter at
// base64LineLength. It then writes the "@EOF" heredoc terminator on a line
// of its own. w is not flushed.
func copyAsBase64(w *bufio.Writer, r io.Reader) error {
	enc := base64.NewEncoder(base64.StdEncoding, newLineWrapWriter(w, base64LineLength))
	_, err := io.Copy(enc, r)
	if err == nil {
		// Close flushes any partially encoded final block.
		err = enc.Close()
	}
	if err == nil {
		_, err = w.WriteString("\n@EOF\n")
	}
	return err
}

// path returns a remote absolute path for a storage object name.
//
// The name is joined to the storage root and cleaned. An error is returned
// if the result lies outside the storage root, e.g. via "..". The root
// itself (name "") is allowed.
func (s *SSHStorage) path(name string) (string, error) {
	root := path.Clean(s.remotepath)
	remotepath := path.Join(root, name)
	// Compare against the root plus a separator: a plain string-prefix test
	// would accept siblings such as "<root>.tmp/x" or "<root>2".
	inside := remotepath == root ||
		strings.HasPrefix(remotepath, strings.TrimSuffix(root, "/")+"/")
	if !inside {
		return "", fmt.Errorf("%q escapes storage directory", name)
	}
	return remotepath, nil
}

// Get implements storage.StorageReader.Get.
//
// The object is base64-encoded on the remote host under a shared lock and
// decoded locally; the whole object is held in memory. If the object does
// not exist, a NotFound error is returned.
func (s *SSHStorage) Get(name string) (io.ReadCloser, error) {
	logger.Debugf("getting %q from storage", name)
	path, err := s.path(name)
	if err != nil {
		return nil, err
	}
	filename := utils.ShQuote(path)
	out, err := s.runf(flockShared, "(test -e %s || (echo No such file && exit 1)) && base64 < %s", filename, filename)
	if err != nil {
		// run also returns plain errors (e.g. a broken session or an
		// unparseable exit code), so the assertion must be checked rather
		// than allowed to panic.
		if sshErr, ok := err.(SSHStorageError); ok && strings.Contains(sshErr.Output, "No such file") {
			return nil, errors.NewNotFound(sshErr, path+" not found")
		}
		return nil, err
	}
	// The remote base64 output is line-wrapped; the decoder ignores the
	// newline characters.
	decoded, err := base64.StdEncoding.DecodeString(out)
	if err != nil {
		return nil, err
	}
	return ioutil.NopCloser(bytes.NewBuffer(decoded)), nil
}

// List implements storage.StorageReader.List.
//
// It returns the sorted names, relative to the storage directory, of all
// regular files whose name begins with prefix. The listing is produced by
// running find on the remote host under a shared lock. A missing directory
// yields no names rather than an error.
func (s *SSHStorage) List(prefix string) ([]string, error) {
	remotepath, err := s.path(prefix)
	if err != nil {
		return nil, err
	}
	root := path.Clean(s.remotepath)
	// rootDir is the storage directory with exactly one trailing slash;
	// reported names are relative to it.
	rootDir := strings.TrimSuffix(root, "/") + "/"

	// Search the directory containing the prefix and filter by its final
	// element. When the prefix resolves to the root itself (e.g. ""), search
	// the root directly. Splitting the root would search its parent and pick
	// up sibling directories such as the tmp dir.
	dir, base := rootDir, ""
	if remotepath != root {
		dir, base = path.Split(remotepath)
		// s.path cleans away a trailing slash; restore it so that a prefix
		// of "dir/" does not also match "dirfoo".
		if strings.HasSuffix(prefix, "/") {
			base += "/"
		}
	}

	quotedDir := utils.ShQuote(dir)
	out, err := s.runf(flockShared, "(test -d %s && find %s -type f) || true", quotedDir, quotedDir)
	if err != nil {
		return nil, err
	}
	if out == "" {
		// No matching entries.
		return nil, nil
	}
	var names []string
	for _, name := range strings.Split(out, "\n") {
		// Only entries under both the searched directory and the storage
		// root are considered; this also guards the slicing below.
		if !strings.HasPrefix(name, dir) || !strings.HasPrefix(name, rootDir) {
			continue
		}
		if strings.HasPrefix(name[len(dir):], base) {
			names = append(names, name[len(rootDir):])
		}
	}
	sort.Strings(names)
	return names, nil
}

// URL implements storage.StorageReader.URL.
//
// The URL has the form "sftp://<host>/<remote path>". Because the remote
// path is absolute, the host is followed by a double slash.
func (s *SSHStorage) URL(name string) (string, error) {
	remote, err := s.path(name)
	if err == nil {
		return fmt.Sprintf("sftp://%s/%s", s.host, remote), nil
	}
	return "", err
}

// DefaultConsistencyStrategy implements
// storage.StorageReader.DefaultConsistencyStrategy.
//
// It returns the zero-value AttemptStrategy: no fields are set.
func (s *SSHStorage) DefaultConsistencyStrategy() utils.AttemptStrategy {
	var strategy utils.AttemptStrategy
	return strategy
}

// ShouldRetry is specified in the StorageReader interface.
//
// This storage never asks for a retry: every error is treated as final.
func (s *SSHStorage) ShouldRetry(err error) bool {
	const retryable = false
	return retryable
}

// Put implements storage.StorageWriter.Put.
//
// The data is sent base64-encoded to the remote session. It is written to a
// temporary file in s.tmpdir and then moved into place with mv, all under an
// exclusive lock. The move is only atomic when tmpdir and the storage
// directory share a filesystem. A failure at any step returns a non-nil
// error, and the temporary file is removed.
func (s *SSHStorage) Put(name string, r io.Reader, length int64) error {
	logger.Debugf("putting %q (len %d) to storage", name, length)
	path, err := s.path(name)
	if err != nil {
		return err
	}
	path = utils.ShQuote(path)
	tmpdir := utils.ShQuote(s.tmpdir)

	// Write to a temporary file ($TMPFILE), then mv atomically. The command
	// substitution and $TMPFILE are double-quoted so paths containing
	// whitespace are not word-split by the remote shell.
	command := fmt.Sprintf(`mkdir -p "$(dirname %s)" && cat > "$TMPFILE"`, path)
	// On failure, remove the temporary file but still exit non-zero. A bare
	// "|| rm -f" would make a failed write look successful.
	command = fmt.Sprintf(
		`TMPFILE=$(mktemp --tmpdir=%s) && ((%s && mv "$TMPFILE" %s) || (rm -f "$TMPFILE"; false))`,
		tmpdir, command, path,
	)

	_, err = s.run(flockExclusive, command+"\n", r, length)
	return err
}

// Remove implements storage.StorageWriter.Remove.
//
// The object is deleted with "rm -f" under an exclusive lock, so removing
// a missing object is not an error.
func (s *SSHStorage) Remove(name string) error {
	target, err := s.path(name)
	if err != nil {
		return err
	}
	_, err = s.runf(flockExclusive, "rm -f %s", utils.ShQuote(target))
	return err
}

// RemoveAll implements storage.StorageWriter.RemoveAll.
//
// Everything matched by the shell glob "<storage dir>/*" is deleted under an
// exclusive lock, and the storage directory itself is kept. With default
// shell options the glob does not match dot-files.
func (s *SSHStorage) RemoveAll() error {
	quotedRoot := utils.ShQuote(s.remotepath)
	if _, err := s.runf(flockExclusive, "rm -fr %s/*", quotedRoot); err != nil {
		return err
	}
	return nil
}
