1 module calc;
2 
3 import ExprBaseVisitor : BaseVisitor = ExprBaseVisitor;
4 import ExprLexer : Lexer = ExprLexer;
5 import ExprParser : Parser = ExprParser;
6 import std.stdio;
7 import std.variant;
8 
9 version (unittest) import dshould;
10 
/// Parses one expression from `file`, evaluates it and prints the `int` result.
///
/// Nothing is evaluated or printed by this function when the parser reports
/// one or more syntax errors.
void eval(File file)
{
    auto setup = Setup(file);
    auto tree = setup.parser.expr;

    // Guard clause: only a syntactically valid tree is handed to the visitor.
    if (setup.parser.numberOfSyntaxErrors != 0)
        return;

    writeln(setup.visitor.visit(tree).get!int);
}
25 
/// Pairs an expression parser reading `input` with a fresh evaluating visitor.
struct Setup
{
    /// Parser over the token stream lexed from the constructor's input.
    Parser parser;

    /// Visitor used to evaluate trees produced by `parser`.
    Visitor visitor;

    /// Builds the input stream -> lexer -> token stream -> parser pipeline.
    ///
    /// `input` may be anything `ANTLRInputStream` accepts; this module passes
    /// both a `File` and a `string`.
    this(T)(T input)
    {
        import antlr.v4.runtime.ANTLRInputStream : ANTLRInputStream;
        import antlr.v4.runtime.CommonTokenStream : CommonTokenStream;

        auto tokens = new CommonTokenStream(new Lexer(new ANTLRInputStream(input)));

        parser = new Parser(tokens);
        visitor = new Visitor;
    }
}
45 
/// Parse-tree visitor that evaluates an arithmetic expression.
///
/// Every visit method returns a `Variant`. The leaves are built from `int`
/// literals, so callers in this module read results back with `get!int`.
class Visitor : BaseVisitor
{
    /// Converts the literal's `INT` token text to an `int` value.
    override Variant visitLiteral(Parser.LiteralContext ctx)
    {
        import std.conv : to;

        return Variant(ctx.INT.getText.get!string.to!int);
    }

    /// Evaluates both operands, then applies the operator `ctx.op`
    /// (`+`, `-`, `*` or `/`).
    ///
    /// Throws: `Exception` when the right operand of `/` evaluates to zero.
    override Variant visitBinaryExpr(Parser.BinaryExprContext ctx)
    {
        auto lhs = visit(ctx.lhs);
        auto rhs = visit(ctx.rhs);
        // Read the operator text once instead of once per comparison.
        auto op = ctx.op.getText;

        if (op == "+")
            return Variant(lhs + rhs);
        if (op == "-")
            return Variant(lhs - rhs);
        if (op == "*")
            return Variant(lhs * rhs);
        if (op == "/")
        {
            // Integer division by zero is not a catchable D exception; on
            // common platforms it traps and kills the process. Reject it here
            // with a regular, catchable error instead.
            if (rhs == 0)
                throw new Exception("division by zero");
            return Variant(lhs / rhs);
        }
        assert(0, "unsupported binary operator");
    }

    /// A parenthesized expression evaluates to its inner expression.
    override Variant visitParens(Parser.ParensContext ctx)
    {
        return visit(ctx.expr);
    }
}
77 
@("evaluate expression")
unittest
{
    // Multiplication binds tighter than addition: 1 + (2 * 3).
    auto setup = Setup("1 + 2 * 3");
    auto tree = setup.parser.expr;

    setup.parser.numberOfSyntaxErrors.should.equal(0);
    setup.visitor.visit(tree).get!int.should.equal(7);
}
92 
@("evaluate expression with parentheses")
unittest
{
    // Parentheses override precedence: (1 + 2) * 3.
    auto setup = Setup("(1 + 2) * 3");
    auto tree = setup.parser.expr;

    setup.parser.numberOfSyntaxErrors.should.equal(0);
    setup.visitor.visit(tree).get!int.should.equal(9);
}
107 
@("evaluate invalid expression")
unittest
{
    auto setup = Setup("1 + * 3");
    auto listener = new TestErrorListener;

    // Register before parsing so the listener sees the syntax error.
    setup.parser.addErrorListener(listener);

    auto tree = setup.parser.expr;

    setup.parser.numberOfSyntaxErrors.should.equal(1);
    listener.msg.should.equal("extraneous input '*' expecting {'(', INT}");
}
123 
version (unittest)
{
    import antlr.v4.runtime.BaseErrorListener;
    import antlr.v4.runtime.InterfaceRecognizer;
    import antlr.v4.runtime.RecognitionException;

    /// Error listener that keeps the message of the most recent syntax error,
    /// so tests can assert on the parser's diagnostic text.
    class TestErrorListener : BaseErrorListener
    {
        /// Message from the latest `syntaxError` call; `null` until one occurs.
        string msg;

        override void syntaxError(
            InterfaceRecognizer recognizer,
            Object offendingSymbol,
            int line,
            int charPositionInLine,
            string msg,
            RecognitionException e)
        {
            // The `msg` parameter shadows the field, hence the explicit `this.`.
            this.msg = msg;
        }
    }
}