module calc;

import ExprBaseVisitor : BaseVisitor = ExprBaseVisitor;
import ExprLexer : Lexer = ExprLexer;
import ExprParser : Parser = ExprParser;
import std.stdio;
import std.variant;

version (unittest) import dshould;

/**
 * Parses one expression from `file`, evaluates it and prints the integer result.
 *
 * Nothing is printed when the parser reports one or more syntax errors.
 *
 * Throws: `Exception` (from `Visitor.visitBinaryExpr`) on division by zero
 *         or on `int.min / -1`.
 */
void eval(File file)
{
    with (Setup(file))
    {
        auto ctx = parser.expr;

        if (parser.numberOfSyntaxErrors == 0)
        {
            const result = visitor.visit(ctx);

            writeln(result.get!int);
        }
    }
}

/// Bundles a parser and an evaluating visitor for a single input.
struct Setup
{
    /// Parser reading the token stream produced from the input.
    Parser parser;

    /// Visitor that evaluates parse trees produced by `parser`.
    Visitor visitor;

    /**
     * Builds the input stream -> lexer -> token stream -> parser pipeline.
     *
     * Params:
     *   input = anything `ANTLRInputStream` can be constructed from; this
     *           module passes both a `File` and a `string`.
     */
    this(T)(T input)
    {
        import antlr.v4.runtime.ANTLRInputStream : ANTLRInputStream;
        import antlr.v4.runtime.CommonTokenStream : CommonTokenStream;

        auto inputStream = new ANTLRInputStream(input);
        auto lexer = new Lexer(inputStream);
        auto commonTokenStream = new CommonTokenStream(lexer);

        parser = new Parser(commonTokenStream);
        visitor = new Visitor;
    }
}

/// Evaluates an expression parse tree to an `int` wrapped in a `Variant`.
class Visitor : BaseVisitor
{
    /// Returns the value of an integer literal token as an `int`.
    override Variant visitLiteral(Parser.LiteralContext ctx)
    {
        import std.conv : to;

        return Variant(ctx.INT.getText.get!string.to!int);
    }

    /**
     * Evaluates `lhs op rhs` for `op` in `+`, `-`, `*`, `/` using signed `int`
     * arithmetic.
     *
     * Throws: `Exception` on division by zero and on `int.min / -1`. Both are
     *         not well-defined for `int` and would otherwise typically crash
     *         the process (SIGFPE on x86) rather than raise a catchable error.
     */
    override Variant visitBinaryExpr(Parser.BinaryExprContext ctx)
    {
        import std.conv : to;
        import std.exception : enforce;

        // Unwrap both operands explicitly so the arithmetic below is plain
        // signed int arithmetic instead of Variant's operator dispatch.
        const lhs = visit(ctx.lhs).get!int;
        const rhs = visit(ctx.rhs).get!int;

        // Read the operator text once; `to!string` accepts both a Variant
        // and a plain string.
        switch (ctx.op.getText.to!string)
        {
            case "+":
                return Variant(lhs + rhs);
            case "-":
                return Variant(lhs - rhs);
            case "*":
                return Variant(lhs * rhs);
            case "/":
                enforce(rhs != 0, "division by zero");
                enforce(!(lhs == int.min && rhs == -1), "integer overflow in division");
                return Variant(lhs / rhs);
            default:
                assert(0, "unknown binary operator");
        }
    }

    /// A parenthesized expression evaluates to its inner expression.
    override Variant visitParens(Parser.ParensContext ctx)
    {
        return visit(ctx.expr);
    }
}

@("evaluate expression")
unittest
{
    with (Setup("1 + 2 * 3"))
    {
        auto ctx = parser.expr;

        parser.numberOfSyntaxErrors.should.equal(0);

        const result = visitor.visit(ctx).get!int;

        result.should.equal(7);
    }
}

@("evaluate expression with parentheses")
unittest
{
    with (Setup("(1 + 2) * 3"))
    {
        auto ctx = parser.expr;

        parser.numberOfSyntaxErrors.should.equal(0);

        const result = visitor.visit(ctx).get!int;

        result.should.equal(9);
    }
}

@("evaluate subtraction with negative result")
unittest
{
    with (Setup("2 - 5"))
    {
        auto ctx = parser.expr;

        parser.numberOfSyntaxErrors.should.equal(0);

        const result = visitor.visit(ctx).get!int;

        result.should.equal(-3);
    }
}

@("evaluate division by zero")
unittest
{
    import std.exception : assertThrown;

    with (Setup("1 / 0"))
    {
        auto ctx = parser.expr;

        parser.numberOfSyntaxErrors.should.equal(0);

        assertThrown!Exception(visitor.visit(ctx));
    }
}

@("evaluate invalid expression")
unittest
{
    with (Setup("1 + * 3"))
    {
        auto errorListener = new TestErrorListener;

        parser.addErrorListener(errorListener);

        auto ctx = parser.expr;

        parser.numberOfSyntaxErrors.should.equal(1);
        errorListener.msg.should.equal("extraneous input '*' expecting {'(', INT}");
    }
}

version (unittest)
{
    import antlr.v4.runtime.BaseErrorListener;
    import antlr.v4.runtime.InterfaceRecognizer;
    import antlr.v4.runtime.RecognitionException;

    /// Error listener that records the message of the most recent syntax error.
    class TestErrorListener : BaseErrorListener
    {
        /// Message passed to the last `syntaxError` call.
        string msg;

        override void syntaxError(InterfaceRecognizer recognizer, Object offendingSymbol, int line,
                int charPositionInLine, string msg, RecognitionException e)
        {
            this.msg = msg;
        }
    }
}