# TODO: examples
# FIXME: error messages

class ParseError(Exception):
    """Raised by `Parser.parse` when the input does not match.

    Attributes:
    msg -- the failure message reported by the failing parser; may be None
           when that parser carried no description (e.g. `satisfy` without
           `desc`)
    """
    def __init__(self, msg=""):
        # Forward to Exception so `.args`, pickling and tracebacks see the message.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        # __str__ must return a str; a None message would make str(err) raise.
        return "" if self.msg is None else str(self.msg)

class Parser:
    """Parser :: (( String           -- Input String
                  , Int              -- Cursor position within input
                  , (a -> Int -> r)  -- Success continuation
                  , (err -> r)       -- Failure continuation
                  ) -> r) -> Parser

    Wraps a continuation-passing function `go(inp, pos, cont, fail)`. On
    success `go` calls `cont(value, newPos)`; on failure it calls
    `fail(message)`. It returns whatever the chosen continuation returns.
    """

    def __init__(self, go):
        self._go = go

    def __str__(self):
        """ __str__ :: Parser -> String """
        return "<Parser>"

    def parse(self, inp):
        """parse :: Parser a -> String -> a

        Run the parser on the given input string, starting at position 0.
        Return the parsed value, or raise `ParseError` on a failing parse.
        Any input left unconsumed after a successful parse is ignored.

        Arguments:
        inp -- the input string
        """
        def cont(v, p): return v
        def fail(err): raise ParseError(err)

        return self._go(inp, 0, cont, fail)

    def __call__(self, inp):
        """See `parse`"""
        return self.parse(inp)

    def map(self, f):
        """map :: (a -> b) -> Parser a -> Parser b

        Modify the return value of the given parser by applying the given
        function.

        Arguments:
        f -- the function to apply

        Returns: A parser which returns the modified value when applied
        """
        def _go(inp, pos, cont, fail):
            def newCont(val, pos):
                return cont(f(val), pos)
            return self._go(inp, pos, newCont, fail)
        return Parser(_go)

    def asThough(self, v):
        """as :: b -> Parser a -> Parser b

        Override the return value of the given parser to use the given value
        instead.
        """
        return self.map(lambda _: v)

    def bind(self, f):
        """bind :: Parser a -> (a -> Parser b) -> Parser b

        Construct a parser from a given parser and a function which has access
        to its inner value. `f` is called with the parsed value and must
        return the parser to run next, starting where this one stopped.

        Arguments:
        f -- the binding function

        Returns: The Parser constructed from the binding func.
        """
        def _go(inp, pos, cont, fail):
            def newCont(val, npos):
                return f(val)._go(inp, npos, cont, fail)
            return self._go(inp, pos, newCont, fail)
        return Parser(_go)

    def __rshift__(self, other):
        """See `bind`"""
        return self.bind(other)

    def then(self, other):
        """Sequence two parsers, ignoring the left result.

        `other` is passed through `lift`, so strings and lists are accepted.
        """
        return self.bind(lambda _: lift(other))

    def over(self, tail):
        """Sequence two parsers, ignoring the right result.

        `tail` is passed through `lift`, so strings and lists are accepted.
        """
        return self.bind(lambda v: lift(tail).asThough(v))

    def orElse(self, other):
        """orElse :: Parser a -> Parser b -> Parser (Either a b)

        Attempt to apply the left parser, and then the right parser if the left
        one fails to parse. The right parser starts from the same position as
        the left one did.

        If both fail, the failure message is the non-empty messages of both
        sides joined with " or ". Missing (None) messages are dropped.

        Arguments:
        other -- A parser to apply if the left parser fails

        Returns: A Parser which accepts anything accepted by either left or
        right.
        """
        def _go(inp, pos, cont, fail):
            def leftFail(leftMsg):
                def rightFail(rightMsg):
                    # Either side may report None (e.g. an undescribed
                    # `satisfy`); `None + str` would raise TypeError. Keep the
                    # real messages and separate them so they stay readable.
                    msgs = [str(m) for m in (leftMsg, rightMsg) if m]
                    return fail(" or ".join(msgs))
                return other._go(inp, pos, cont, rightFail)
            return self._go(inp, pos, cont, leftFail)
        return Parser(_go)

    def __or__(self, other):
        """See `orElse`"""
        return self.orElse(other)

    def desc(self, msg):
        """Attach a message to the given parser to use for descriptive error
        messages on failure. It replaces whatever message the inner parser
        reported.
        """
        def _go(inp, pos, cont, fail):
            def newFail(_):
                return fail(msg)
            return self._go(inp, pos, cont, newFail)
        return Parser(_go)

def fail(msg):
    """fail :: err -> Parser a

    Return a parser that never succeeds. It consumes no input and hands `msg`
    straight to the failure continuation.
    """
    return Parser(lambda inp, pos, cont, onFail: onFail(msg))

def pure(val):
    """pure :: a -> Parser a

    Return a parser that always succeeds with `val`, leaving the cursor where
    it was (no input is consumed).
    """
    return Parser(lambda inp, pos, cont, fail: cont(val, pos))

def optional(prs, default=None):
    """optional :: Parser a -> Parser (Maybe a)

    Return a parser which attempts to apply the inner parser. On failure,
    continue as if no input was consumed with the given default value.

    Arguments:
    prs -- the parser which may or may not succeed (anything `lift` accepts)
    default -- the value to return if the inner parser fails (default: None)

    Raises:
    TypeError -- at construction time, if `prs` cannot be lifted
    """
    # Lift once up front instead of on every parse attempt; lifting a string
    # rebuilds a whole chain of char parsers.
    p = lift(prs)

    def _go(inp, pos, cont, fail):
        def useDefault(_):
            # Backtrack: resume at the original position with `default`.
            return cont(default, pos)
        return p._go(inp, pos, cont, useDefault)

    return Parser(_go)

def many(prs):
    """many :: Parser a -> Parser [a]

    Construct a parser which applies the given parser as many times as possible
    (zero or more), returning collected results in a list. It never fails.

    Repetition stops in two cases:
    - The inner parser fails. Any input it consumed on that failed attempt is
      given back, and parsing resumes right after the last successful item.
    - The inner parser succeeds without consuming input. Repeating it would
      loop forever (e.g. `many(many(p))`), so that value is discarded and
      repetition ends.

    NB: continuation-passing recursion means every collected item adds stack
    frames, so very long repetitions can hit Python's recursion limit.
    """
    p = lift(prs)

    def _go(inp, pos, cont, fail):
        # Fresh accumulator per run, so nested or re-entrant uses of this
        # parser never share state.
        results = []

        def step(cur):
            def onItem(val, nxt):
                if nxt <= cur:
                    # No forward progress: stop instead of looping forever.
                    return cont(results, cur)
                results.append(val)
                return step(nxt)

            def onStop(_):
                # Resume after the last *successful* item, not at the start.
                return cont(results, cur)

            return p._go(inp, cur, onItem, onStop)

        return step(pos)

    return Parser(_go)

def many1(prs):
    """As `many`, except the inner parser must succeed at least once.

    If the first application fails, the whole parser fails with the inner
    parser's failure message. Otherwise the first value is prepended to the
    list collected by `many`.
    """
    first = lift(prs)

    def withRest(head):
        return many(prs).map(lambda tail: [head] + tail)

    return first.bind(withRest)

def manyTill(prs, end):
    """Repeatedly apply a given parser until a secondary end parser succeeds.

    Before each item, `end` is tried first. If it succeeds, its input is
    consumed, its value is discarded, and the items collected so far are
    returned as a list. If it fails, `prs` is applied. A failure of `prs`
    fails the whole parser with `prs`'s message.

    NB: if `end` never matches and `prs` succeeds without consuming input,
    this repeats forever (the original recursive version behaved the same).

    Arguments:
    prs -- parser for each item (anything `lift` accepts)
    end -- terminating parser (anything `lift` accepts)
    """
    p = lift(prs)
    e = lift(end)

    def _go(inp, pos, cont, fail):
        # Fresh list per run, so repeated/nested runs never share results.
        results = []

        def step(cur):
            def atEnd(_, nxt):
                return cont(results, nxt)

            def notAtEnd(_):
                def onItem(val, nxt):
                    results.append(val)
                    return step(nxt)
                return p._go(inp, cur, onItem, fail)

            return e._go(inp, cur, atEnd, notAtEnd)

        return step(pos)

    return Parser(_go)

def choice(*prss):
    """Construct a parser by `orElse`ing the given parsers together, in order.

    Accepts the parsers either as separate arguments or as a single list or
    tuple. Each one is passed through `lift`. With no choices at all, the
    result always fails with "choice: empty list of choices".
    """
    if len(prss) == 1 and isinstance(prss[0], (list, tuple)):
        prss = prss[0]

    if not prss:
        return fail("choice: empty list of choices")

    # Seed the fold with the first real alternative. Seeding it with a failing
    # parser glued the "empty list" text onto every genuine failure message.
    p = lift(prss[0])
    for prs in prss[1:]:
        p = p.orElse(lift(prs))
    return p

def sequence(*prss):
    """Construct a parser by `bind`ing the given parsers together in order.

    The results of the inner parsers are returned as an in-order list. The
    parsers may be given as separate arguments or as a single list or tuple;
    each one is passed through `lift` once, when the parser is built. With no
    parsers it succeeds with an empty list without consuming input.
    """
    if len(prss) == 1 and isinstance(prss[0], (list, tuple)):
        prss = prss[0]

    p = pure([])
    for q in [lift(x) for x in prss]:
        # `q=q` freezes the current parser (closures bind late). `l + [v]`
        # builds a new list instead of appending, because `pure([])` hands
        # back the same list object on every run and mutating it would leak
        # results between parses.
        p = p.bind(lambda l, q=q: q.map(lambda v: l + [v]))
    return p

def satisfy(pred,desc=None):
    """satisfy :: (Char -> Bool) -> Parser Char

    Build a parser that consumes exactly one character when `pred` accepts it
    and returns that character. At end of input, or when `pred` rejects the
    character, it fails with `desc` as the message.

    Arguments:
    pred -- predicate applied to the next character
    desc -- description of what a match looks like (default: None)
    """

    def _go(inp, pos, cont, fail):
        if pos >= len(inp):
            return fail(desc)
        ch = inp[pos]
        return cont(ch, pos + 1) if pred(ch) else fail(desc)

    return Parser(_go)

def char(c):
    """Construct a parser matching exactly the single character `c`.

    On success it returns the matched character. On failure the message is
    "expected: '<c>'".

    NB: the argument check is an `assert`, so it is skipped under `python -O`.
    """
    assert type(c) is str and len(c) == 1, "char: input must be a single char"
    message = "expected: '{}'".format(c)
    return satisfy(lambda ch: ch == c, message)

def string(s):
    """Construct a parser matching the literal text `s`, one char at a time.

    Succeeds with `s` itself. On a mismatch it fails with the message of the
    first character that did not match ("expected: '<ch>'"). An empty `s`
    succeeds without consuming input.
    """
    matcher = pure([])
    for single in map(char, s):
        matcher = matcher.then(single)
    return matcher.asThough(s)

def lift(p):
    """Convert an argument of arbitrary type to the corresponding parser.

    - a `Parser` (or subclass) is returned unchanged;
    - a `str` becomes `string(p)`, matching that literal text;
    - a list or tuple becomes a `sequence` of its lifted elements;
    - any other object whose type defines its own `__str__` is lifted via
      `str(p)`.

    Raises:
    TypeError -- if the value cannot be converted
    """
    if isinstance(p, Parser):
        return p
    if isinstance(p, str):
        return string(p)
    if isinstance(p, (list, tuple)):
        # Pass a real list: `sequence` only unpacks a single list/tuple, and a
        # bare generator would be treated as one unliftable value.
        return sequence([lift(q) for q in p])
    if type(p).__str__ is not object.__str__:
        # The old `"__str__" in p` was a membership test, not an attribute
        # check. Match objects that customise __str__ by their text form.
        return lift(str(p))
    raise TypeError("lift: can't lift value of type `" + type(p).__name__ + "`")


#------ Examples ------#
# Matches any single character; fails only at end of input.
anyChar = satisfy(lambda _ch: True, desc="any char")

def toNum(s):
    """Join a sequence of characters and convert the text to a number.

    Tries `int` first, then `float`. If neither accepts the text, the joined
    string itself is returned (not the original sequence).
    """
    text = ''.join(s)
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            continue
    return text

def isDigit(c):
    """Return True for characters that may appear in a numeric literal.

    That is anything `str.isdigit` accepts, plus '.', '+' and '-'.
    """
    if c.isdigit():
        return True
    return c in ".+-"

# One or more digit/sign/point characters, converted with `toNum` (which
# falls back to the raw text when it is not a valid int or float).
number = (
    many1(satisfy(isDigit, desc="digit"))
    .map(toNum)
    .desc("number")
)

# Sentinel returned by `eol`, distinct from any parsed text.
EOL = object()
eol = satisfy(lambda ch: ch == '\n', desc="EOL").asThough(EOL)

# One or more whitespace characters.
whitespace = many1(satisfy(lambda ch: ch.isspace())).desc("whitespace")

def token(p):
    """Parse `p`, first skipping any whitespace in front of it."""
    leading = optional(whitespace)
    return leading.then(p)

# Characters up to (and consuming) the next run of whitespace.
somestring = manyTill(anyChar, end=whitespace)

