// Copyright (c) 2021 Oasis Labs Inc. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
// TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package h2c

import (
	"crypto"
	"fmt"
	"io"
	"math"

	"golang.org/x/crypto/sha3"
)

// oversizeDST is the fixed prefix that is hashed together with a domain
// separation tag longer than 255 bytes in order to derive the shortened
// tag used in its place.
var oversizeDST = []byte("H2C-OVERSIZE-DST-")

// ExpandMessageXMD implements expand_message_xmd, overwriting out with
// uniformly random data generated by the provided hash function, domain
// separation tag, and message.
//
// The requested output length (len_in_bytes) is len(out).  Domain
// separation tags longer than 255 bytes are first replaced by
// H("H2C-OVERSIZE-DST-" || DST).
//
// An error is returned, and out is left unmodified, if hFunc is not
// available (unknown, or not linked into the binary), if its digest is
// shorter than 2*kay/8 bytes, if len(out) is 0 or exceeds 65535, or if
// more than 255 digest-sized blocks would be required.
func ExpandMessageXMD(out []byte, hFunc crypto.Hash, domainSeparator, message []byte) error {
	// 0. Ensure parameters are sensible.
	//
	// crypto.Hash.Size panics on unknown hash identifiers and
	// crypto.Hash.New panics on hashes that are not linked in, so
	// availability is checked before either is called.
	if !hFunc.Available() {
		return fmt.Errorf("h2c: hash function unavailable: %v", hFunc)
	}

	lenInBytes := len(out)
	bInBytes := hFunc.Size()
	if bInBytes < 2*kay/8 {
		return fmt.Errorf("h2c: b_in_bytes insufficiently large: %d", bInBytes)
	}
	if lenInBytes == 0 || lenInBytes > math.MaxUint16 {
		return fmt.Errorf("h2c: len_in_bytes out of range: %d", lenInBytes)
	}

	// 1. ell = ceil(len_in_bytes / b_in_bytes)
	// 2. ABORT if ell > 255
	ell := (lenInBytes + bInBytes - 1) / bInBytes
	if ell > 255 {
		return fmt.Errorf("h2c: ell out of range: %d", ell)
	}

	// Note: hash.Hash.Write is documented to never return an error, so
	// the results of the Write calls below are deliberately discarded.
	h := hFunc.New()

	// 5.4.3 Using DSTs longer than 255 bytes.
	dst := domainSeparator
	if len(dst) > math.MaxUint8 {
		// DST = H("H2C-OVERSIZE-DST-" || a_very_long_DST)
		_, _ = h.Write(oversizeDST)
		_, _ = h.Write(dst)
		dst = h.Sum(nil)
		h.Reset()
	}

	// DST_prime = DST || I2OSP(len(DST), 1)
	//
	// Built once in a fresh buffer (never appended onto the caller's
	// slice) as it is hashed into every block.
	dstPrime := make([]byte, 0, len(dst)+1)
	dstPrime = append(dstPrime, dst...)
	dstPrime = append(dstPrime, byte(len(dst)))

	// 7. b_0 = H(msg_prime), where
	//    msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime
	//    Z_pad     = I2OSP(0, r_in_bytes)
	//    l_i_b_str = I2OSP(len_in_bytes, 2)
	_, _ = h.Write(make([]byte, h.BlockSize()))
	_, _ = h.Write(message)
	_, _ = h.Write([]byte{byte(lenInBytes >> 8), byte(lenInBytes), 0})
	_, _ = h.Write(dstPrime)
	b0 := h.Sum(nil)

	// 8. b_1 = H(b_0 || I2OSP(1, 1) || DST_prime)
	h.Reset()
	_, _ = h.Write(b0)
	_, _ = h.Write([]byte{1})
	_, _ = h.Write(dstPrime)
	bi := h.Sum(nil)

	// 11. uniform_bytes = b_1 || ... || b_ell
	// 12. return substr(uniform_bytes, 0, len_in_bytes)
	//
	// copy truncates to whatever space is left in out, which implements
	// the substr, and also covers the case where b_1 alone is sufficient.
	written := copy(out, bi)

	// 9. for i in (2, ..., ell):
	for i := 2; written < lenInBytes; i++ {
		// 10. b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime)
		//
		// bi holds b_(i - 1) on entry, is xored with b_0 in place, and
		// is then reused as the destination for b_i.
		for j, v := range b0 {
			bi[j] ^= v
		}

		h.Reset()
		_, _ = h.Write(bi)
		_, _ = h.Write([]byte{byte(i)}) // i <= ell <= 255
		_, _ = h.Write(dstPrime)
		bi = h.Sum(bi[:0])

		written += copy(out[written:], bi)
	}

	return nil
}

// newXOF returns a clone of xofFunc that has been reset, leaving xofFunc
// itself untouched.
func newXOF(xofFunc sha3.ShakeHash) sha3.ShakeHash {
	fresh := xofFunc.Clone()
	fresh.Reset()
	return fresh
}

// ExpandMessageXOF implements expand_message_xof, overwriting out with
// uniformly random data generated by the provided extensible-output
// function, domain separation tag, and message.
//
// The requested output length (len_in_bytes) is len(out), which must be
// in the range [1, 65535].  Domain separation tags longer than 255 bytes
// are replaced by 2*kay/8 bytes of XOF output over
// "H2C-OVERSIZE-DST-" || DST.
//
// Note: xofFunc is only ever used as a template; working instances are
// obtained via its Clone() method.  At present there are 3 different XOF
// interfaces in the x/crypto package, all mutually incompatible due to
// the return type of Clone(), which is why a sha3.ShakeHash is required
// here.
func ExpandMessageXOF(out []byte, xofFunc sha3.ShakeHash, domainSeparator, message []byte) error {
	// 0. Ensure parameters are sensible.
	n := len(out)
	if n == 0 || n > math.MaxUint16 {
		return fmt.Errorf("h2c: len_in_bytes out of range: %d", n)
	}

	// 5.4.3 Using DSTs longer than 255 bytes.
	tag := domainSeparator
	if len(tag) > math.MaxUint8 {
		// DST = H("H2C-OVERSIZE-DST-" || a_very_long_DST, 2 * k / 8)
		shortener := newXOF(xofFunc)
		_, _ = shortener.Write(oversizeDST)
		_, _ = shortener.Write(tag)

		tag = make([]byte, 2*kay/8)
		if _, err := io.ReadFull(shortener, tag); err != nil {
			return fmt.Errorf("h2c: failed to read shortened DST: %w", err)
		}
	}

	// Absorb msg_prime piece by piece into a fresh instance of the XOF,
	// which avoids having to assemble it in a temporary buffer.
	//
	// 1. DST_prime = DST || I2OSP(len(DST), 1)
	// 2. msg_prime = msg || I2OSP(len_in_bytes, 2) || DST_prime
	absorber := newXOF(xofFunc)
	_, _ = absorber.Write(message)
	_, _ = absorber.Write([]byte{byte(n >> 8), byte(n)})
	_, _ = absorber.Write(tag)
	_, _ = absorber.Write([]byte{byte(len(tag))})

	// 3. uniform_bytes = H(msg_prime, len_in_bytes)
	if _, err := io.ReadFull(absorber, out); err != nil {
		return fmt.Errorf("h2c: failed to read XOF output: %w", err)
	}

	return nil
}
