finn-lang/Interpreter.cs

664 lines
16 KiB
C#

using System;
using System.Linq;
using System.Collections.Generic;
using System.Collections.Immutable;
using AST = Finn.AST;
using System.Runtime.CompilerServices;
using System.Net;
using Finn.AST;
namespace Finn;
/// <summary>
/// Error raised while evaluating a Finn program. It carries the token at which
/// the failure was detected, alongside the human-readable message.
/// </summary>
public class RuntimeError : Exception
{
    /// <summary>Token at which the error was detected.</summary>
    public readonly Token Token;

    internal RuntimeError(Token token, String message)
        : base(message) => this.Token = token;
}
/// <summary>
/// A lexical scope that maps variable names to runtime values. Each scope is
/// chained to an optional enclosing scope that is searched when a name is not
/// bound locally.
/// </summary>
public class Env
{
    private readonly Env? enclosing;
    private readonly Dictionary<string, object> values = new Dictionary<string, object>();

    /// <summary>Creates an outermost scope with no enclosing scope.</summary>
    public Env()
    {
        enclosing = null;
    }

    /// <summary>Creates a scope nested inside <paramref name="enclosing"/>.</summary>
    public Env(Env enclosing)
    {
        this.enclosing = enclosing;
    }

    /// <summary>
    /// Binds <paramref name="name"/> in this scope only, never in an enclosing
    /// scope. An existing binding of the same name in this scope is overwritten.
    /// </summary>
    public object this[string name]
    {
        set
        {
            // Redefinition within the same scope should never happen; the front
            // end is expected to reject it, so here we simply overwrite.
            values[name] = value;
        }
    }

    /// <summary>
    /// Looks up the variable named by <paramref name="identifier"/>. This scope
    /// is searched first, then each enclosing scope outward.
    /// </summary>
    /// <exception cref="RuntimeError">No scope in the chain binds the name.</exception>
    public object this[Token identifier]
    {
        get
        {
            var name = (string)identifier.Literal!;
            // Every variable reference goes through this lookup, so walk the
            // chain with TryGetValue. Catching KeyNotFoundException at every
            // level would be slow, and it would recurse once per scope.
            for (Env? scope = this; scope != null; scope = scope.enclosing)
            {
                if (scope.values.TryGetValue(name, out var found))
                {
                    return found;
                }
            }
            throw new RuntimeError(identifier, $"Undefined variable {name}.");
        }
    }
}
/// <summary>
/// A runtime value that can be invoked by a Finn call expression.
/// </summary>
public interface Callable
{
    /// <summary>
    /// Invokes the value with arguments that the caller has already evaluated.
    /// </summary>
    public object Call(Interpreter interpreter, object[] arguments);
}
/// <summary>
/// A built-in function implemented in C#. Calling it ignores the interpreter
/// and hands the argument array straight to <see cref="Function"/>.
/// </summary>
public record NativeFunction(Func<object[], object> Function) : Callable
{
    public object Call(Interpreter interpreter, object[] arguments) => Function(arguments);

    public override string ToString() => "<native func>";
}
public class Interpreter : AST.IExprVisitor<Env, object>
{
/// <summary>
/// Raised when a value does not have the shape that a pattern requires.
/// </summary>
private class PatternMismatchException : Exception
{
    /// <summary>Token identifying where the offending pattern starts.</summary>
    public readonly Token Start;

    public PatternMismatchException(Token start, string message)
        : base(message) => Start = start;
}
/// <summary>
/// A pattern mismatch caused solely by the variant tags differing. The
/// <c>when</c> evaluator treats this as "try the next case" rather than as an error.
/// </summary>
private class VariantTagMismatchException : PatternMismatchException
{
    public VariantTagMismatchException(Token start, string patternTag, string valueTag)
        : base(start, $"Pattern tag {patternTag} does not match value tag {valueTag}.")
    {
    }
}
/// <summary>
/// Destructures a runtime value against a pattern and writes every variable the
/// pattern names into the supplied environment. The visitor context is the pair
/// (value being matched, environment receiving the bindings).
/// </summary>
/// <remarks>
/// Throws <see cref="PatternMismatchException"/>, or its subclass
/// <see cref="VariantTagMismatchException"/>, when the value does not fit the
/// pattern. Bindings made before the mismatch was found stay in the environment.
/// </remarks>
class PatternBinder : AST.IPatternVisitor<(object, Env), ValueTuple>
{
    public ValueTuple visitFieldPatternPattern((object, Env) context, AST.FieldPattern pattern)
    {
        // A field with no sub-pattern binds the field's value to a variable
        // named after the field.
        return (pattern.Pattern ?? new AST.SimplePattern(pattern.Name, pattern.Name)).accept(context, this);
    }

    public ValueTuple visitRecordPatternPattern((object, Env) context, AST.RecordPattern pattern)
    {
        var (obj, env) = context;
        if (obj is not Record r)
        {
            throw new PatternMismatchException(pattern.Start, $"Matched value {obj} is not a record.");
        }
        var removedLabels = new List<string>();
        foreach (var field in pattern.Fields)
        {
            string name = (string)field.Name.Literal!;
            // Treat a missing field as a pattern mismatch. Without this check,
            // Record.Get throws an ArgumentException that neither `when` nor
            // Interpret handles.
            if (!r.Fields.ContainsKey(name))
            {
                throw new PatternMismatchException(field.Name, $"Matched record has no field {name}.");
            }
            removedLabels.Add(name);
            field.accept((r.Get(name), env), this);
        }
        if (pattern.Rest != null)
        {
            // NOTE(review): Without() drops every value stacked under each
            // matched label, not only the visible one (Record.Remove pops just
            // one). Confirm this is the intended meaning of the rest pattern.
            pattern.Rest.accept((r.Without(removedLabels), env), this);
        }
        return ValueTuple.Create();
    }

    public ValueTuple visitSimplePatternPattern((object, Env) context, AST.SimplePattern pattern)
    {
        var (obj, env) = context;
        // A simple pattern without an identifier matches anything and binds nothing.
        if (pattern.Identifier == null)
        {
            return ValueTuple.Create();
        }
        env[(string)pattern.Identifier.Literal!] = obj;
        return ValueTuple.Create();
    }

    public ValueTuple visitVariantPatternPattern((object, Env) context, AST.VariantPattern pattern)
    {
        var (obj, env) = context;
        if (obj is not Variant v)
        {
            throw new PatternMismatchException(pattern.Start, $"Matched value {obj} is not a variant.");
        }
        var tag = (string)pattern.Tag.Literal!;
        if (v.Tag != tag)
        {
            // Tag mismatches get their own exception type so that `when` can
            // fall through to its next case.
            throw new VariantTagMismatchException(pattern.Start, tag, v.Tag);
        }
        if (v.Value == null && pattern.Argument == null)
        {
            return ValueTuple.Create();
        }
        if (v.Value != null && pattern.Argument != null)
        {
            return pattern.Argument.accept((v.Value, env), this);
        }
        throw new PatternMismatchException(pattern.Start, "Variant pattern arity does not match variant value.");
    }
}
/// <summary>
/// Runtime representation of a Finn list: an immutable sequence backed by
/// <see cref="ImmutableList{T}"/>. <see cref="Add"/> returns a new list.
/// </summary>
private class List
{
    public required ImmutableList<object> Items { get; init; }

    /// <summary>
    /// Shared empty list. It is readonly so that no code can replace it at runtime.
    /// </summary>
    public static readonly List Empty = new List { Items = ImmutableList<object>.Empty };

    /// <summary>
    /// Element at <paramref name="i"/>, after the index is truncated toward zero
    /// to an int. Throws when the truncated index is out of range.
    /// </summary>
    public object this[double i] => Items[(int)i];

    public bool IsEmpty => Items.IsEmpty;

    /// <summary>
    /// Returns a new list with <paramref name="item"/> appended. This list is unchanged.
    /// </summary>
    public List Add(object item) => new List { Items = Items.Add(item) };

    /// <summary>Formats as <c>[ a, b, ]</c>, or <c>[ ]</c> when empty.</summary>
    public override string ToString() =>
        "[" + string.Concat(Items.Select(item => $" {item},")) + " ]";
}
/// <summary>
/// Runtime representation of a Finn record. Each label maps to a stack of
/// values. Extending a record with a label it already has pushes a new value
/// that shadows the old one, and the top of the stack is the visible value.
/// Every operation returns a new record and leaves this one unchanged.
/// </summary>
private class Record
{
    /// <summary>
    /// Record with no fields. It is readonly so that no code can replace it at runtime.
    /// </summary>
    public static readonly Record Empty = new Record
    {
        Fields = ImmutableSortedDictionary<string, ImmutableStack<object>>.Empty,
    };

    public required ImmutableSortedDictionary<string, ImmutableStack<object>> Fields { get; init; }

    /// <summary>Returns the value stack stored under <paramref name="name"/>.</summary>
    /// <exception cref="ArgumentException">The record has no such field.</exception>
    private ImmutableStack<object> stackFor(string name)
    {
        if (!Fields.TryGetValue(name, out var values))
        {
            throw new ArgumentException($"no such field: {name}");
        }
        return values;
    }

    /// <summary>
    /// Replaces the visible value of <paramref name="name"/>. Shadowed values are kept.
    /// </summary>
    /// <exception cref="ArgumentException">The record has no such field.</exception>
    public Record Update(string name, object value) =>
        new Record { Fields = Fields.SetItem(name, stackFor(name).Pop().Push(value)) };

    /// <summary>Visible value of field <paramref name="name"/>.</summary>
    /// <exception cref="ArgumentException">The record has no such field.</exception>
    public object Get(string name) => stackFor(name).Peek();

    /// <summary>
    /// Drops every value, visible and shadowed, stored under each of <paramref name="labels"/>.
    /// </summary>
    public Record Without(IEnumerable<string> labels) =>
        new Record { Fields = Fields.RemoveRange(labels) };

    /// <summary>
    /// Removes only the visible value of <paramref name="name"/>, which exposes
    /// any value it was shadowing. The label disappears once its last value is removed.
    /// </summary>
    /// <exception cref="ArgumentException">The record has no such field.</exception>
    public Record Remove(string name)
    {
        var popped = stackFor(name).Pop();
        return new Record
        {
            Fields = popped.IsEmpty ? Fields.Remove(name) : Fields.SetItem(name, popped),
        };
    }

    /// <summary>
    /// Pushes <paramref name="value"/> under <paramref name="name"/>, shadowing any existing value.
    /// </summary>
    public Record Extend(string name, object value)
    {
        var values = Fields.GetValueOrDefault(name, ImmutableStack<object>.Empty);
        return new Record { Fields = Fields.SetItem(name, values.Push(value)) };
    }

    /// <summary>
    /// Formats as <c>{ a = 1, b = 2, }</c>. Labels appear in dictionary order,
    /// and each label lists its visible value first, then the values it shadows.
    /// </summary>
    public override string ToString() =>
        "{" + string.Concat(Fields.SelectMany(kv => kv.Value.Select(value => $" {kv.Key} = {value},"))) + " }";
}
/// <summary>
/// Runtime representation of a tagged variant, printed as <c>`tag</c> or
/// <c>`tag(value)</c>. Booleans are the payload-less variants <c>`true</c> and
/// <c>`false</c>. Because this is a record, two variants are equal when their
/// tags and payloads are equal.
/// </summary>
private record Variant(string Tag, object? Value)
{
    public static readonly Variant True = new Variant("true", null);
    public static readonly Variant False = new Variant("false", null);

    /// <summary>
    /// Maps a C# bool onto the shared <see cref="True"/> and <see cref="False"/> instances.
    /// </summary>
    public static Variant FromBool(bool b) => b switch { true => True, false => False };

    /// <summary>True when the variant carries no payload.</summary>
    public bool IsEmpty => Value is null;

    public override string ToString() => Value is null ? $"`{Tag}" : $"`{Tag}({Value})";
}
/// <summary>
/// A function defined in Finn code. Calling it binds each argument to the
/// matching parameter pattern in a fresh scope nested inside the closure, then
/// evaluates the function body in that scope.
/// </summary>
public record FinnFunction(FuncBinding binding, Env closure) : Callable
{
    /// <exception cref="RuntimeError">
    /// The argument count differs from the parameter count, or an argument does
    /// not match its parameter pattern.
    /// </exception>
    public object Call(Interpreter interpreter, object[] arguments)
    {
        // Check the count up front. Otherwise a short argument list indexes
        // past `arguments` and crashes, and extra arguments are silently dropped.
        if (arguments.Length != binding.Params.Length)
        {
            throw new RuntimeError(binding.Name,
                $"Expected {binding.Params.Length} arguments but got {arguments.Length}.");
        }
        Env env = new Env(closure);
        var binder = new PatternBinder();
        for (int i = 0; i < arguments.Length; i++)
        {
            try
            {
                binding.Params[i].accept((arguments[i], env), binder);
            }
            catch (PatternMismatchException e)
            {
                // Convert here so a parameter mismatch surfaces as a reportable
                // RuntimeError. If it escaped, an enclosing `when` could mistake
                // it for its own case pattern failing to match.
                throw new RuntimeError(e.Start, e.Message);
            }
        }
        return interpreter.evaluate(env, binding.Value);
    }
}
/// <summary>Outermost scope; Interpret evaluates programs in it.</summary>
protected internal readonly Env Globals = new Env();

/// <summary>Creates an interpreter whose global scope contains the native built-ins.</summary>
public Interpreter()
{
    // clock(): wall-clock seconds since the Unix epoch, with millisecond
    // resolution so it can time intervals shorter than a second. UtcNow
    // avoids a pointless local-time-zone conversion, and the epoch offset is
    // identical either way.
    Globals["clock"] = new NativeFunction((args) =>
    {
        return DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000.0;
    });
}
/// <summary>
/// Evaluates <paramref name="expression"/> in the global scope and prints the
/// result. A <see cref="RuntimeError"/> is handed to <c>Program.runtimeError</c>
/// instead of propagating; other exceptions are not caught here.
/// </summary>
public void Interpret(AST.Expr expression)
{
    try
    {
        Console.WriteLine(evaluate(Globals, expression));
    }
    catch (RuntimeError error)
    {
        Program.runtimeError(error);
    }
}
/// <summary>
/// Evaluates <paramref name="expr"/> in <paramref name="env"/> by dispatching
/// to the matching visit method.
/// </summary>
private object evaluate(Env env, AST.Expr expr) => expr.accept(env, this);
/// <summary>
/// Returns <paramref name="operand"/> as a list, or raises a runtime error at <paramref name="op"/>.
/// </summary>
private List checkListOperand(Token op, Object operand)
{
    if (operand is List l) return l;
    throw new RuntimeError(op, "Operand must be a list.");
}
/// <summary>
/// Returns <paramref name="operand"/> as a record, or raises a runtime error at <paramref name="op"/>.
/// </summary>
private Record checkRecordOperand(Token op, Object operand) =>
    operand is Record r ? r : throw new RuntimeError(op, "Operand must be a record.");

/// <summary>
/// Returns <paramref name="operand"/> as a number, or raises a runtime error at <paramref name="op"/>.
/// </summary>
private double checkNumberOperand(Token op, Object operand) =>
    operand is double d ? d : throw new RuntimeError(op, "Operand must be a number.");

/// <summary>Raises a runtime error at <paramref name="op"/> unless both operands are numbers.</summary>
private void checkNumberOperands(Token op, object left, object right)
{
    if (left is not double || right is not double)
    {
        throw new RuntimeError(op, "Operands must be numbers.");
    }
}

/// <summary>Raises a runtime error at <paramref name="op"/> unless both operands are strings.</summary>
private void checkStringOperands(Token op, object left, object right)
{
    if (left is not string || right is not string)
    {
        throw new RuntimeError(op, "Operands must be strings.");
    }
}

/// <summary>
/// Returns <paramref name="operand"/> if it equals <c>`true</c> or <c>`false</c>.
/// Otherwise raises a runtime error at <paramref name="op"/>.
/// </summary>
private Variant checkBoolOperand(Token op, object operand)
{
    if (operand is Variant v && (v == Variant.True || v == Variant.False))
    {
        return v;
    }
    throw new RuntimeError(op, "Operand must be <true,false>.");
}
/// <summary>
/// Evaluates both operands, left first, then applies the binary operator.
/// Arithmetic and ordering operators require two numbers, <c>++</c> concatenates
/// two strings, and equality compares two numbers or two strings. Comparisons
/// yield the <c>`true</c> or <c>`false</c> variant.
/// </summary>
/// <exception cref="RuntimeError">An operand has the wrong type for the operator.</exception>
public object visitBinaryExpr(Env env, AST.Binary expr)
{
    var left = evaluate(env, expr.Left);
    var right = evaluate(env, expr.Right);

    // Equality is defined only on two numbers or two strings. Any other pair
    // is reported as a RuntimeError. A bare ArgumentException would escape
    // Interpret's error handling and crash the interpreter.
    bool operandsEqual() => (left, right) switch
    {
        (string l, string r) => l == r,
        (double l, double r) => l == r,
        _ => throw new RuntimeError(expr.Op, "Operands must be two numbers or two strings."),
    };

    switch (expr.Op.Type)
    {
        case TokenType.Minus:
            checkNumberOperands(expr.Op, left, right);
            return (double)left - (double)right;
        case TokenType.Plus:
            checkNumberOperands(expr.Op, left, right);
            return (double)left + (double)right;
        case TokenType.Slash:
            checkNumberOperands(expr.Op, left, right);
            return (double)left / (double)right;
        case TokenType.Asterisk:
            checkNumberOperands(expr.Op, left, right);
            return (double)left * (double)right;
        case TokenType.PlusPlus:
            checkStringOperands(expr.Op, left, right);
            return (string)left + (string)right;
        case TokenType.Greater:
            checkNumberOperands(expr.Op, left, right);
            return Variant.FromBool((double)left > (double)right);
        case TokenType.GreaterEqual:
            checkNumberOperands(expr.Op, left, right);
            return Variant.FromBool((double)left >= (double)right);
        case TokenType.Less:
            checkNumberOperands(expr.Op, left, right);
            return Variant.FromBool((double)left < (double)right);
        case TokenType.LessEqual:
            checkNumberOperands(expr.Op, left, right);
            return Variant.FromBool((double)left <= (double)right);
        case TokenType.BangEqual:
            return Variant.FromBool(!operandsEqual());
        case TokenType.DoubleEqual:
            return Variant.FromBool(operandsEqual());
    }
    throw new ArgumentException($"bad binary op: {expr.Op}");
}
/// <summary>
/// Evaluates the callee, then each argument from left to right, and invokes the callee.
/// </summary>
/// <exception cref="RuntimeError">The callee is not callable.</exception>
/// <exception cref="NotImplementedException">An argument slot is empty (partial application).</exception>
public object visitCallExpr(Env env, AST.Call expr)
{
    var callee = evaluate(env, expr.Left);
    object[] args = new object[expr.Arguments.Length];
    for (int i = 0; i < args.Length; i++)
    {
        var argExpr = expr.Arguments[i];
        if (argExpr is null)
        {
            throw new NotImplementedException("partial application not implemented");
        }
        args[i] = evaluate(env, argExpr);
    }
    // A pattern test rather than a cast: calling a non-function must be a
    // reportable RuntimeError, not an InvalidCastException crash.
    if (callee is not Callable function)
    {
        throw new RuntimeError(expr.Left.Start, "Can only call functions.");
    }
    return function.Call(this, args);
}
/// <summary>A grouping evaluates to the value of the expression it wraps.</summary>
public object visitGroupingExpr(Env env, AST.Grouping expr) => evaluate(env, expr.Expression);

/// <summary>
/// Looks the variable up through the scope chain. Undefined names raise a RuntimeError.
/// </summary>
public object visitVariableExpr(Env env, AST.Variable expr) => env[expr.Value];
/// <summary>
/// Evaluates the condition, which must be <c>`true</c> or <c>`false</c>. Then
/// evaluates and returns exactly one of the two branches.
/// </summary>
public object visitIfExpr(Env env, AST.If expr)
{
    var condition = checkBoolOperand(expr.Condition.Start, evaluate(env, expr.Condition));
    return condition == Variant.True
        ? evaluate(env, expr.Then)
        : evaluate(env, expr.Else);
}
/// <summary>
/// Indexes into a list. Yields <c>`some(item)</c> when the index is a whole
/// number within bounds, and <c>`nothing</c> otherwise.
/// </summary>
/// <exception cref="RuntimeError">The indexed value is not a list, or the index is not a number.</exception>
public object visitIndexerExpr(Env env, AST.Indexer expr)
{
    var left = checkListOperand(expr.Left.Start, evaluate(env, expr.Left));
    var index = checkNumberOperand(expr.Index.Start, evaluate(env, expr.Index));
    // Check bounds explicitly instead of swallowing every exception the indexer
    // might throw. A fractional index is treated as out of range rather than
    // silently truncated. NaN fails every comparison, so it never reaches the
    // double-to-int cast.
    bool inBounds = index >= 0 && index < left.Items.Count && Math.Floor(index) == index;
    return inBounds
        ? new Variant("some", left[index])
        : new Variant("nothing", null);
}
/// <summary>
/// Evaluates a let expression. Its bindings are processed in order in a new
/// scope nested inside <paramref name="env"/>, and the body is then evaluated
/// in that scope.
/// </summary>
/// <exception cref="RuntimeError">A value does not match its binding pattern.</exception>
public object visitLetExpr(Env env, Let expr)
{
    var newEnv = new Env(env);
    var binder = new PatternBinder();
    foreach (var binding in expr.Bindings)
    {
        switch (binding)
        {
            case VarBinding(var pattern, var valueExpr):
                // Evaluating in newEnv lets a value refer to earlier bindings
                // in the same list.
                var value = evaluate(newEnv, valueExpr);
                try
                {
                    pattern.accept((value, newEnv), binder);
                }
                catch (Exception e)
                {
                    // Report any destructuring failure as a RuntimeError. Use the
                    // binder's location when it has one, otherwise the binding's.
                    var start = e is PatternMismatchException mismatch ? mismatch.Start : binding.Start;
                    throw new RuntimeError(start, e.Message);
                }
                break;
            case FuncBinding fb:
                // The closure is newEnv itself, so the function can reach its
                // own name and any binding of this let that exists when it runs.
                newEnv[(string)fb.Name.Literal!] = new FinnFunction(fb, newEnv);
                break;
            default:
                throw new InvalidOperationException($"Unsupported binding type: {binding.GetType().Name}");
        }
    }
    return evaluate(newEnv, expr.Body);
}
/// <summary>
/// Evaluates the elements from left to right and appends each one, starting
/// from the shared empty list.
/// </summary>
public object visitListExpr(Env env, AST.List expr) =>
    expr.Elements.Aggregate(List.Empty, (acc, element) => acc.Add(evaluate(env, element)));
/// <summary>A literal evaluates to the value stored in its AST node.</summary>
public object visitLiteralExpr(Env env, AST.Literal expr) => expr.Value;
/// <summary>
/// Builds a record value. If the expression has a base, the base is evaluated
/// first and each update replaces the visible value of an existing field.
/// Extensions are then pushed on top, shadowing any field with the same name.
/// A field written without a value takes the variable of the same name.
/// </summary>
/// <exception cref="RuntimeError">
/// Raised when the base is not a record, when an update names a field the base
/// lacks, or when a label repeats among the updates or among the extensions.
/// </exception>
public object visitRecordExpr(Env env, AST.Record expr)
{
    Record rec = Record.Empty;
    if (expr.Base != null)
    {
        if (evaluate(env, expr.Base.Value) is not Record baseRec)
        {
            throw new RuntimeError(expr.Base.Value.Start, "Base record must be a record.");
        }
        // Updates
        var updateLabels = new HashSet<string>();
        foreach (AST.Field update in expr.Base.Updates)
        {
            var label = (string)update.Name.Literal!;
            if (!updateLabels.Add(label))
            {
                throw new RuntimeError(update.Name, "Record updates must be to unique fields.");
            }
            // Record.Update fails only when the field is absent. Report that
            // directly, instead of catching everything and blaming a type mismatch.
            if (!baseRec.Fields.ContainsKey(label))
            {
                throw new RuntimeError(update.Name, $"Cannot update missing field {label}.");
            }
            var updateValue = update.Value == null ? env[update.Name] : evaluate(env, update.Value);
            baseRec = baseRec.Update(label, updateValue);
        }
        rec = baseRec;
    }
    // Extensions
    var extLabels = new HashSet<string>();
    foreach (AST.Field extension in expr.Extensions)
    {
        var label = (string)extension.Name.Literal!;
        if (!extLabels.Add(label))
        {
            throw new RuntimeError(extension.Name, "Record extensions must have unique field names.");
        }
        var extensionValue = extension.Value == null ? env[extension.Name] : evaluate(env, extension.Value);
        rec = rec.Extend(label, extensionValue);
    }
    return rec;
}
/// <summary>Reads the visible value of a field from a record.</summary>
/// <exception cref="RuntimeError">The operand is not a record, or it lacks the selected field.</exception>
public object visitSelectorExpr(Env env, AST.Selector expr)
{
    var r = checkRecordOperand(expr.Left.Start, evaluate(env, expr.Left));
    var fieldName = (string)expr.FieldName.Literal!;
    // Test for the field explicitly rather than swallowing every exception Get might throw.
    if (!r.Fields.ContainsKey(fieldName))
    {
        throw new RuntimeError(expr.FieldName, "Operand must have selected field.");
    }
    return r.Get(fieldName);
}
/// <summary>
/// Evaluates the left expression for its effects, discards its value, and
/// returns the value of the right expression.
/// </summary>
public object visitSequenceExpr(Env env, AST.Sequence expr)
{
    _ = evaluate(env, expr.Left);
    return evaluate(env, expr.Right);
}
/// <summary>
/// Applies a prefix operator: <c>-</c> negates a number, and <c>!</c> flips a
/// <c>`true</c> or <c>`false</c> variant.
/// </summary>
public object visitUnaryExpr(Env env, AST.Unary expr)
{
    object operand = evaluate(env, expr.Right);
    return expr.Op.Type switch
    {
        TokenType.Minus => (object)(-checkNumberOperand(expr.Op, operand)),
        TokenType.Bang => checkBoolOperand(expr.Op, operand) == Variant.True ? Variant.False : Variant.True,
        // Unreachable
        _ => throw new Exception($"bad unary op: {expr.Op}"),
    };
}
/// <summary>
/// Builds a variant from its tag and, when an argument is present, the evaluated payload.
/// </summary>
public object visitVariantExpr(Env env, AST.Variant expr)
{
    var tag = (string)expr.Tag.Literal!;
    object? payload = expr.Argument == null ? null : evaluate(env, expr.Argument);
    return new Variant(tag, payload);
}
/// <summary>
/// Evaluates the head once, then tries each case in order. The first case whose
/// pattern matches has its body evaluated in a new scope holding the pattern's
/// bindings. A case is skipped only when its pattern fails on a variant tag.
/// </summary>
/// <exception cref="RuntimeError">
/// No case matches, or a pattern fails for some reason other than a differing
/// variant tag (for example, the head is not a variant at all).
/// </exception>
public object visitWhenExpr(Env env, AST.When expr)
{
    var head = evaluate(env, expr.Head);
    var binder = new PatternBinder();
    foreach (var c in expr.Cases)
    {
        var newEnv = new Env(env);
        try
        {
            c.Pattern.accept((head, newEnv), binder);
        }
        catch (VariantTagMismatchException)
        {
            continue;
        }
        catch (PatternMismatchException e)
        {
            throw new RuntimeError(e.Start, e.Message);
        }
        // Evaluate the body outside the try. A tag mismatch raised while the
        // body runs (for example from a function's parameter pattern) must not
        // be misread as this case's pattern failing to match.
        return evaluate(newEnv, c.Value);
    }
    throw new RuntimeError(expr.Start, "No matching patterns.");
}
}