// This file is part of MinIO dperf
// Copyright (c) 2021 MinIO, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package dperf

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"syscall"
	"time"

	"github.com/minio/pkg/v3/rng"
	"github.com/ncw/directio"
	"golang.org/x/sys/unix"
)

// nullWriter is an io.Writer that throws away everything written to it.
type nullWriter struct{}

// Write reports all of b as consumed without storing it and never fails.
func (nullWriter) Write(b []byte) (n int, err error) {
	n = len(b)
	return n, nil
}

// runReadTest runs the read benchmark for path with I/O index 0 and returns
// whatever runReadTestWithIndex reports for it.
func (d *DrivePerf) runReadTest(ctx context.Context, path string, data []byte) (uint64, error) {
	const firstIOIndex = 0
	return d.runReadTestWithIndex(ctx, path, data, firstIOIndex)
}

// runReadTestWithIndex reads d.FileSize bytes from the file at path and
// returns the observed throughput in bytes per second.
//
// When d.BlockSize is at least DirectioAlignSize the file is opened with
// O_DIRECT so the page cache is bypassed. For smaller block sizes the file is
// opened normally and FADV_DONTNEED hints are issued before and after the
// read instead. data is the scratch buffer handed to copyAligned and ioIndex
// is forwarded unchanged in progress updates.
//
// ctx is checked once before any I/O starts; a read that is already in
// progress is not interrupted by cancellation.
//
// An error is returned if ctx is already done, if the file cannot be opened,
// if the copy fails, or if fewer than d.FileSize bytes were read.
func (d *DrivePerf) runReadTestWithIndex(ctx context.Context, path string, data []byte, ioIndex int) (uint64, error) {
	// Do not start a potentially long read for an already cancelled run.
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	startTime := time.Now()

	// For reads, prefer O_DIRECT to bypass page cache when possible.
	// For small block sizes (< 4KiB), use regular I/O with FADV_DONTNEED to drop cache.
	useDirectIO := d.BlockSize >= DirectioAlignSize
	flags := os.O_RDONLY
	if useDirectIO {
		flags |= syscall.O_DIRECT
	}

	r, err := os.OpenFile(path, flags, 0o400)
	if err != nil {
		return 0, err
	}

	// The fadvise calls are hints only: a failure must not abort the
	// benchmark, so their results are deliberately discarded.
	_ = unix.Fadvise(int(r.Fd()), 0, int64(d.FileSize), unix.FADV_SEQUENTIAL)

	// For non-O_DIRECT reads, advise kernel to not cache the data.
	if !useDirectIO {
		_ = unix.Fadvise(int(r.Fd()), 0, int64(d.FileSize), unix.FADV_DONTNEED)
	}

	progressWriter := &progressTracker{
		w:        &nullWriter{},
		callback: d.ProgressCallback,
		// Drive path: the parent of the testUUID directory holding the file.
		path:       filepath.Dir(filepath.Dir(path)),
		phase:      "read",
		totalBytes: d.FileSize,
		ioIndex:    ioIndex,
		startTime:  startTime,
	}

	// The file is closed right after the copy, on success and on failure,
	// before the result is inspected.
	n, err := copyAligned(progressWriter, r, data, int64(d.FileSize), r.Fd(), !useDirectIO)
	r.Close()
	if err != nil {
		return 0, err
	}
	if n != int64(d.FileSize) {
		return 0, fmt.Errorf("Expected read %d, read %d", d.FileSize, n)
	}

	// Drop any cached pages after reading to ensure future reads are also
	// uncached. Best effort: failures to reopen or advise are ignored.
	if !useDirectIO {
		if f, err := os.Open(path); err == nil {
			_ = unix.Fadvise(int(f.Fd()), 0, 0, unix.FADV_DONTNEED)
			f.Close()
		}
	}

	dt := float64(time.Since(startTime))
	throughputInSeconds := (float64(d.FileSize) / dt) * float64(time.Second)
	return uint64(throughputInSeconds), nil
}

// alignedBlock - returns the buffer produced by directio.AlignedBlock for the
// requested blockSize.
func alignedBlock(blockSize int) []byte {
	block := directio.AlignedBlock(blockSize)
	return block
}

// fdatasync - flushes the data of the file referred to by fd to disk using
// fdatasync(2).
//
// Unlike fsync(), fdatasync() does not flush modified metadata unless that
// metadata is needed for a later read of the data to be handled correctly.
// Changes to st_atime or st_mtime, for instance, are not flushed, whereas a
// change of the file size (st_size, e.g. via ftruncate(2)) is. This keeps
// disk activity down for callers that do not need every piece of metadata
// synchronized.
//
// The error from the underlying system call is returned unchanged.
func fdatasync(fd int) error {
	if err := syscall.Fdatasync(fd); err != nil {
		return err
	}
	return nil
}

// fadviseSequential issues a FADV_SEQUENTIAL hint for the first length bytes
// of f, starting at offset 0, and returns the result of unix.Fadvise.
func fadviseSequential(f *os.File, length int64) error {
	const offset = 0
	fd := int(f.Fd())
	return unix.Fadvise(fd, offset, length, unix.FADV_SEQUENTIAL)
}

// nullReader is an io.Reader that never touches the destination buffer: each
// Read simply claims the whole buffer was filled, until ctx is done.
type nullReader struct {
	ctx context.Context
}

// Read returns (0, ctx.Err()) once the context is done; otherwise it returns
// (len(b), nil) and leaves the contents of b as they were.
func (n nullReader) Read(b []byte) (int, error) {
	if err := n.ctx.Err(); err != nil {
		return 0, err
	}
	return len(b), nil
}

// newRandomReader returns the reader created by rng.NewReader (presumably a
// pseudo-random data source - confirm against the rng package).
//
// It panics if rng.NewReader reports an error. ctx is currently not used by
// this function.
func newRandomReader(ctx context.Context) io.Reader {
	src, err := rng.NewReader()
	if err == nil {
		return src
	}
	panic(err)
}

// disableDirectIO - clears the O_DIRECT status flag on fd.
//
// The current flags are fetched with F_GETFL, O_DIRECT is masked out and the
// result is written back with F_SETFL. The first fcntl error encountered is
// returned.
func disableDirectIO(fd uintptr) error {
	current, err := unix.FcntlInt(fd, unix.F_GETFL, 0)
	if err != nil {
		return err
	}
	if _, err = unix.FcntlInt(fd, unix.F_SETFL, current&^syscall.O_DIRECT); err != nil {
		return err
	}
	return nil
}

// DirectioAlignSize - alignment, in bytes, required for direct I/O (4 KiB).
// It is declared here rather than taken from directio.AlignSize because
// directio.AlignSize is defined as 0 on macOS, which would cause a
// divide-by-zero error when the value is used as a modulus.
const DirectioAlignSize = 4096

// copyAligned - copies from r to w through alignedBuf, one buffer-sized chunk
// at a time, similar in spirit to io.Copy.
//
// If totalSize is positive, exactly totalSize bytes are copied; if it is
// negative, data is copied until r reports EOF; if it is zero, nothing is
// copied and (0, nil) is returned.
//
// fd is the descriptor of the O_DIRECT file taking part in the copy (the
// source file in the read test, the destination file in the write test).
// With syncMode false, alignedBuf is expected to be aligned to 4K page
// boundaries, and whenever a chunk is not a multiple of DirectioAlignSize,
// O_DIRECT is switched off on fd with disableDirectIO before that chunk is
// transferred. The caller is expected to fdatasync the file before closing
// it. With syncMode true (O_DSYNC/O_SYNC or plain I/O instead of O_DIRECT),
// alignment is irrelevant and fd is never touched.
//
// It returns the number of bytes written to w. Errors from r and w are
// returned as they are. In addition, io.ErrShortBuffer is returned when
// alignedBuf is empty, io.ErrShortWrite when w accepts fewer bytes than it
// was given, and io.ErrUnexpectedEOF when r ends before totalSize bytes have
// been copied.
func copyAligned(w io.Writer, r io.Reader, alignedBuf []byte, totalSize int64, fd uintptr, syncMode bool) (int64, error) {
	if totalSize == 0 {
		return 0, nil
	}

	if len(alignedBuf) == 0 {
		// io.ReadFull on an empty buffer returns (0, nil), so the loop
		// below could neither make progress nor ever observe EOF.
		return 0, io.ErrShortBuffer
	}

	var written int64
	for {
		buf := alignedBuf
		if totalSize > 0 {
			// Never transfer more than what is left of totalSize.
			if remaining := totalSize - written; remaining < int64(len(buf)) {
				buf = buf[:remaining]
			}
		}

		// In sync mode, we don't need to worry about alignment since we're not using O_DIRECT
		if !syncMode && len(buf)%DirectioAlignSize != 0 {
			// Disable O_DIRECT on fd's on unaligned buffer
			// perform an amortized Fdatasync(fd) on the fd at
			// the end, this is performed by the caller before
			// closing 'w'.
			if err := disableDirectIO(fd); err != nil {
				return written, err
			}
		}

		nr, err := io.ReadFull(r, buf)
		eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
		if err != nil && !eof {
			return written, err
		}

		// Only the bytes actually read are forwarded.
		buf = buf[:nr]
		var (
			n  int
			un int
			nw int64
		)

		remain := len(buf) % DirectioAlignSize
		// In sync mode, treat all buffers as "aligned" (no special handling needed)
		if syncMode || remain == 0 {
			// buf is aligned for directio write() or we're in sync mode
			n, err = w.Write(buf)
			nw = int64(n)
		} else {
			// Write the aligned prefix first, if there is one.
			if remain < len(buf) {
				n, err = w.Write(buf[:len(buf)-remain])
				if err != nil {
					return written, err
				}
				nw = int64(n)
			}

			// Disable O_DIRECT on fd's on unaligned buffer
			// perform an amortized Fdatasync(fd) on the fd at
			// the end, this is performed by the caller before
			// closing 'w'.
			if err = disableDirectIO(fd); err != nil {
				return written, err
			}

			// The tail is not aligned, so it is written only after
			// O_DIRECT has been switched off.
			un, err = w.Write(buf[len(buf)-remain:])
			nw += int64(un)
		}

		if nw > 0 {
			written += nw
		}

		if err != nil {
			return written, err
		}

		if nw != int64(len(buf)) {
			return written, io.ErrShortWrite
		}

		if totalSize > 0 && written == totalSize {
			// we have written the entire stream, return right here.
			return written, nil
		}

		if eof {
			// We reached EOF prematurely but we did not write everything
			// that we promised that we would write.
			if totalSize > 0 && written != totalSize {
				return written, io.ErrUnexpectedEOF
			}
			return written, nil
		}
	}
}

// runWriteTest runs the write benchmark for path with I/O index 0 and returns
// whatever runWriteTestWithIndex reports for it.
func (d *DrivePerf) runWriteTest(ctx context.Context, path string, data []byte) (uint64, error) {
	const firstIOIndex = 0
	return d.runWriteTestWithIndex(ctx, path, data, firstIOIndex)
}

// runWriteTestWithIndex writes d.FileSize bytes from newRandomReader to the
// file at path and returns the observed throughput in bytes per second.
//
// The parent directory of path is created if needed and the file is created
// or truncated. It is opened with O_DSYNC when d.SyncMode is set and with
// O_DIRECT otherwise. data is the scratch buffer handed to copyAligned and
// ioIndex is forwarded unchanged in progress updates. The measured time
// covers open, copy, fdatasync and close.
//
// ctx is checked once before anything is created on disk; a write that is
// already in progress is not interrupted by cancellation.
//
// An error is returned if ctx is already done, if the directory or file
// cannot be created, if the copy fails or is short, or if fdatasync or close
// fails. The file is closed on every return path.
func (d *DrivePerf) runWriteTestWithIndex(ctx context.Context, path string, data []byte, ioIndex int) (uint64, error) {
	// Do not create or truncate anything for an already cancelled run.
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return 0, err
	}

	startTime := time.Now()

	// Choose flags based on sync mode.
	flags := os.O_RDWR | os.O_CREATE | os.O_TRUNC
	if d.SyncMode {
		// Use O_DSYNC for synchronized writes (for small block sizes or when --sync is specified)
		flags |= syscall.O_DSYNC
	} else {
		// Use O_DIRECT for direct I/O (bypasses page cache)
		flags |= syscall.O_DIRECT
	}

	w, err := os.OpenFile(path, flags, 0o600)
	if err != nil {
		return 0, err
	}

	progressWriter := &progressTracker{
		w:        w,
		callback: d.ProgressCallback,
		// Drive path: the parent of the testUUID directory holding the file.
		path:       filepath.Dir(filepath.Dir(path)),
		phase:      "write",
		totalBytes: d.FileSize,
		ioIndex:    ioIndex,
		startTime:  startTime,
	}

	n, err := copyAligned(progressWriter, newRandomReader(ctx), data, int64(d.FileSize), w.Fd(), d.SyncMode)
	if err != nil {
		w.Close()
		return 0, err
	}

	if n != int64(d.FileSize) {
		w.Close()
		return 0, fmt.Errorf("Expected to write %d, wrote %d bytes", d.FileSize, n)
	}

	if err := fdatasync(int(w.Fd())); err != nil {
		// Release the descriptor; the sync failure is the error to report.
		w.Close()
		return 0, err
	}

	if err := w.Close(); err != nil {
		return 0, err
	}

	dt := float64(time.Since(startTime))
	throughputInSeconds := (float64(d.FileSize) / dt) * float64(time.Second)
	return uint64(throughputInSeconds), nil
}

// progressTracker is an io.Writer decorator: it forwards every write to w,
// keeps a running byte count and periodically hands a ProgressUpdate to
// callback.
type progressTracker struct {
	w              io.Writer        // destination every Write is forwarded to
	callback       ProgressCallback // receives progress updates; nil disables reporting
	path           string           // reported as ProgressUpdate.Path
	phase          string           // reported as ProgressUpdate.Phase
	totalBytes     uint64           // reported as ProgressUpdate.TotalBytes
	bytesProcessed uint64           // bytes successfully written so far
	ioIndex        int              // reported as ProgressUpdate.IOIndex
	startTime      time.Time        // reference point for the throughput figure
	lastReport     time.Time        // time of the most recent callback invocation
}

// Write forwards b to the wrapped writer. If that write fails, its result is
// returned as is and nothing is counted or reported. Otherwise the byte count
// is updated and, when a callback is set, an update is emitted provided at
// least 10ms have passed since the previous one or the byte count has reached
// totalBytes. The reported throughput is bytes per second since startTime.
func (p *progressTracker) Write(b []byte) (int, error) {
	n, err := p.w.Write(b)
	if err != nil {
		return n, err
	}
	p.bytesProcessed += uint64(n)

	if p.callback == nil {
		return n, nil
	}

	// Rate-limit updates so the consumer is not flooded, but always let the
	// update that completes the transfer through.
	now := time.Now()
	finished := p.bytesProcessed == p.totalBytes
	if !finished && now.Sub(p.lastReport) < 10*time.Millisecond {
		return n, nil
	}

	var throughput uint64
	if elapsed := float64(now.Sub(p.startTime)); elapsed > 0 {
		throughput = uint64((float64(p.bytesProcessed) / elapsed) * float64(time.Second))
	}

	p.callback(ProgressUpdate{
		Path:           p.path,
		Phase:          p.phase,
		BytesProcessed: p.bytesProcessed,
		TotalBytes:     p.totalBytes,
		Throughput:     throughput,
		IOIndex:        p.ioIndex,
		Error:          nil,
	})
	p.lastReport = now

	return n, nil
}
