Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 01a19cf

Browse files
committed
Add support for field accesses and constructors
1 parent b5b0a92 commit 01a19cf

File tree

10 files changed

+202
-43
lines changed

10 files changed

+202
-43
lines changed

‎compiler/src/dotty/tools/dotc/ast/untpd.scala‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
469469
def New(tpt: Tree, argss: List[List[Tree]])(using Context): Tree =
470470
ensureApplied(argss.foldLeft(makeNew(tpt))(Apply(_, _)))
471471

472-
/** A new expression with constrictor and possibly type arguments. See
472+
/** A new expression with constructor and possibly type arguments. See
473473
* `New(tpt, argss)` for details.
474474
*/
475475
def makeNew(tpt: Tree)(using Context): Tree = {

‎compiler/src/dotty/tools/dotc/core/Definitions.scala‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ class Definitions {
661661
@tu lazy val StringClass: ClassSymbol = requiredClass("java.lang.String")
662662
def StringType: Type = StringClass.typeRef
663663
@tu lazy val StringModule: Symbol = StringClass.linkedClass
664+
@tu lazy val String_== : TermSymbol = enterMethod(StringClass, nme.EQ, methOfAnyRef(BooleanType), Final)
664665
@tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final)
665666
@tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match {
666667
case List(pt) => pt.isAny || pt.stripNull().isAnyRef

‎compiler/src/dotty/tools/dotc/qualified_types/EGraph.scala‎

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dotty.tools.dotc.qualified_types
22

33
import scala.collection.mutable
44
import scala.collection.mutable.ArrayBuffer
5+
import scala.collection.mutable.ListBuffer
56

67
import dotty.tools.dotc.ast.tpd.{
78
closureDef,
@@ -25,13 +26,16 @@ import dotty.tools.dotc.core.Constants.Constant
2526
import dotty.tools.dotc.core.Contexts.Context
2627
import dotty.tools.dotc.core.Contexts.ctx
2728
import dotty.tools.dotc.core.Decorators.i
29+
import dotty.tools.dotc.core.Flags
2830
import dotty.tools.dotc.core.Hashable.Binders
2931
import dotty.tools.dotc.core.Names.Designator
3032
import dotty.tools.dotc.core.StdNames.nme
3133
import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol}
3234
import dotty.tools.dotc.core.Types.{
35+
AppliedType,
3336
CachedProxyType,
3437
ConstantType,
38+
LambdaType,
3539
MethodType,
3640
NamedType,
3741
NoPrefix,
@@ -40,14 +44,14 @@ import dotty.tools.dotc.core.Types.{
4044
TermParamRef,
4145
TermRef,
4246
Type,
47+
TypeRef,
4348
TypeVar,
4449
ValueType
4550
}
4651
import dotty.tools.dotc.qualified_types.ENode.Op
4752
import dotty.tools.dotc.reporting.trace
4853
import dotty.tools.dotc.transform.TreeExtractors.BinaryOp
4954
import dotty.tools.dotc.util.Spans.Span
50-
import scala.collection.mutable.ListBuffer
5155

5256
final class EGraph(rootCtx: Context):
5357

@@ -92,7 +96,7 @@ final class EGraph(rootCtx: Context):
9296
private val builtinOps = Map(
9397
d.Int_== -> Op.Equal,
9498
d.Boolean_== -> Op.Equal,
95-
d.Any_== -> Op.Equal,
99+
d.String_== -> Op.Equal,
96100
d.Boolean_&& -> Op.And,
97101
d.Boolean_|| -> Op.Or,
98102
d.Boolean_! -> Op.Not,
@@ -108,9 +112,8 @@ final class EGraph(rootCtx: Context):
108112

109113
def equiv(node1: ENode, node2: ENode)(using Context): Boolean =
110114
trace(i"EGraph.equiv", Printers.qualifiedTypes):
111-
val margin = ctx.base.indentTab * (ctx.base.indent)
115+
//val margin = ctx.base.indentTab * (ctx.base.indent)
112116
// println(s"$margin node1: $node1\n$margin node2: $node2")
113-
// Check if the representents of both nodes are the same
114117
val repr1 = representent(node1)
115118
val repr2 = representent(node2)
116119
repr1 eq repr2
@@ -121,8 +124,8 @@ final class EGraph(rootCtx: Context):
121124
node match
122125
case ENode.Atom(tp) =>
123126
()
124-
case ENode.New(clazz) =>
125-
addUse(clazz, node)
127+
case ENode.Constructor(sym) =>
128+
()
126129
case ENode.Select(qual, member) =>
127130
addUse(qual, node)
128131
case ENode.Apply(fn, args) =>
@@ -138,6 +141,7 @@ final class EGraph(rootCtx: Context):
138141
}
139142
).asInstanceOf[node.type]
140143

144+
// TODO(mbovel): Memoize this
141145
def toNode(tree: Tree, paramSyms: List[Symbol] = Nil, paramTps: List[ENode.ArgRefType] = Nil)(using
142146
Context
143147
): Option[ENode] =
@@ -165,16 +169,18 @@ final class EGraph(rootCtx: Context):
165169
tree match
166170
case Literal(_) | Ident(_) | This(_) if tree.tpe.isInstanceOf[SingletonType] =>
167171
Some(ENode.Atom(mapType(tree.tpe).asInstanceOf[SingletonType]))
168-
case New(clazz) =>
169-
for clazzNode <- toNode(clazz, paramSyms, paramTps) yield ENode.New(clazzNode)
172+
case Select(New(_), nme.CONSTRUCTOR) =>
173+
constructorNode(tree.symbol)
174+
case tree: Select if isCaseClassApply(tree.symbol) =>
175+
constructorNode(tree.symbol.owner.linkedClass.primaryConstructor)
170176
case Select(qual, name) =>
171-
for qualNode <- toNode(qual, paramSyms, paramTps) yield ENode.Select(qualNode, tree.symbol)
177+
for qualNode <- toNode(qual, paramSyms, paramTps) yield normalizeSelect(qualNode, tree.symbol)
172178
case BinaryOp(lhs, op, rhs) if builtinOps.contains(op) =>
173179
for
174180
lhsNode <- toNode(lhs, paramSyms, paramTps)
175181
rhsNode <- toNode(rhs, paramSyms, paramTps)
176182
yield normalizeOp(builtinOps(op), List(lhsNode, rhsNode))
177-
case BinaryOp(lhs, d.Int_-, rhs) if lhs.tpe.isInstanceOf[ValueType] && rhs.tpe.isInstanceOf[ValueType] =>
183+
case BinaryOp(lhs, d.Int_-, rhs) =>
178184
for
179185
lhsNode <- toNode(lhs, paramSyms, paramTps)
180186
rhsNode <- toNode(rhs, paramSyms, paramTps)
@@ -192,7 +198,7 @@ final class EGraph(rootCtx: Context):
192198
case mt: MethodType =>
193199
assert(defDef.termParamss.size == 1, "closures have a single parameter list, right?")
194200
val myParamSyms: List[Symbol] = defDef.termParamss.head.map(_.symbol)
195-
val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty
201+
val myParamTps: ListBuffer[ENode.ArgRefType] = ListBuffer.empty
196202
val paramTpsSize = paramTps.size
197203
for myParamSym <- myParamSyms do
198204
val underlying = mapType(myParamSym.info.subst(myParamSyms.take(myParamTps.size), myParamTps.toList))
@@ -204,15 +210,38 @@ final class EGraph(rootCtx: Context):
204210
case _ =>
205211
None
206212

213+
// TODO(mbovel): Memoize this
214+
private def constructorNode(constr: Symbol)(using Context): Option[ENode.Constructor] =
215+
val clazz = constr.owner
216+
if hasStructuralEquality(clazz) then
217+
val isPrimaryConstructor = constr.denot.isPrimaryConstructor
218+
val fieldsRaw = clazz.denot.asClass.paramAccessors.filter(isPrimaryConstructor && _.isStableMember)
219+
val constrParams = constr.paramSymss.flatten.filter(_.isTerm)
220+
val fields = constrParams.map(p => fieldsRaw.find(_.name == p.name).getOrElse(NoSymbol))
221+
Some(ENode.Constructor(constr)(fields))
222+
else
223+
None
224+
225+
private def hasStructuralEquality(clazz: Symbol)(using Context): Boolean =
226+
val equalsMethod = clazz.info.decls.lookup(nme.equals_)
227+
val equalsNotOverriden = !equalsMethod.exists || equalsMethod.is(Flags.Synthetic)
228+
clazz.isClass && clazz.is(Flags.Case) && equalsNotOverriden
229+
230+
private def isCaseClassApply(meth: Symbol)(using Context): Boolean =
231+
meth.name == nme.apply
232+
&& meth.flags.is(Flags.Synthetic)
233+
&& meth.owner.linkedClass.is(Flags.Case)
234+
207235
private def canonicalize(node: ENode): ENode =
236+
// println(s"canonicalize $node")
208237
representent(unique(
209238
node match
210239
case ENode.Atom(tp) =>
211240
node
212-
case ENode.New(clazz) =>
213-
ENode.New(representent(clazz))
241+
case ENode.Constructor(sym) =>
242+
node
214243
case ENode.Select(qual, member) =>
215-
ENode.Select(representent(qual), member)
244+
normalizeSelect(representent(qual), member)
216245
case ENode.Apply(fn, args) =>
217246
ENode.Apply(representent(fn), args.map(representent))
218247
case ENode.OpApply(op, args) =>
@@ -223,6 +252,33 @@ final class EGraph(rootCtx: Context):
223252
ENode.Lambda(paramTps, retTp, representent(body))
224253
))
225254

255+
private def normalizeSelect(qual: ENode, member: Symbol): ENode =
256+
getAppliedConstructor(qual) match
257+
case Some(constr) =>
258+
val memberIndex = constr.fields.indexOf(member)
259+
260+
if memberIndex >= 0 then
261+
val args = getTermArguments(qual)
262+
assert(args.size == constr.fields.size)
263+
args(memberIndex)
264+
else
265+
ENode.Select(qual, member)
266+
case None =>
267+
ENode.Select(qual, member)
268+
269+
private def getAppliedConstructor(node: ENode): Option[ENode.Constructor] =
270+
node match
271+
case ENode.Apply(fn, args) => getAppliedConstructor(fn)
272+
case ENode.TypeApply(fn, args) => getAppliedConstructor(fn)
273+
case node: ENode.Constructor => Some(node)
274+
case _ => None
275+
276+
private def getTermArguments(node: ENode): List[ENode] =
277+
node match
278+
case ENode.Apply(fn, args) => getTermArguments(fn) ::: args
279+
case ENode.TypeApply(fn, args) => getTermArguments(fn)
280+
case _ => Nil
281+
226282
private def normalizeOp(op: ENode.Op, args: List[ENode]): ENode =
227283
op match
228284
case Op.Equal =>
@@ -316,12 +372,10 @@ final class EGraph(rootCtx: Context):
316372
(a, b) match
317373
case (ENode.Atom(_: ConstantType), _) => (a, b)
318374
case (_, ENode.Atom(_: ConstantType)) => (b, a)
319-
case (ENode.Atom(_: SkolemType), _) => (a, b)
320-
case (_, ENode.Atom(_: SkolemType)) => (b, a)
375+
case (_: ENode.Constructor, _) => (a, b)
376+
case (_, _: ENode.Constructor) => (b, a)
321377
case (_: ENode.Atom, _) => (a, b)
322378
case (_, _: ENode.Atom) => (b, a)
323-
case (_: ENode.New, _) => (a, b)
324-
case (_, _: ENode.New) => (b, a)
325379
case (_: ENode.Select, _) => (a, b)
326380
case (_, _: ENode.Select) => (b, a)
327381
case (_: ENode.Apply, _) => (a, b)
@@ -336,8 +390,6 @@ final class EGraph(rootCtx: Context):
336390
if aRepr eq bRepr then return
337391
assert(aRepr != bRepr, s"$aRepr and $bRepr are `equals` but not `eq`")
338392

339-
// TODO(mbovel): if both nodes are objects, recursively merge their arguments
340-
341393
/// Update represententOf and usedBy maps
342394
val (newRepr, oldRepr) = order(aRepr, bRepr)
343395
represententOf(oldRepr) = newRepr
@@ -371,8 +423,9 @@ final class EGraph(rootCtx: Context):
371423
node match
372424
case ENode.Atom(tp) =>
373425
singleton(tp)
374-
case ENode.New(clazz) =>
375-
New(toTree(clazz, paramRefs))
426+
case ENode.Constructor(sym) =>
427+
val tycon = sym.owner.info.typeConstructor
428+
New(tycon).select(TermRef(tycon, sym))
376429
case ENode.Select(qual, member) =>
377430
toTree(qual, paramRefs).select(member)
378431
case ENode.Apply(fn, args) =>

‎compiler/src/dotty/tools/dotc/qualified_types/ENode.scala‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ enum ENode:
2323
import ENode.*
2424

2525
case Atom(tp: SingletonType)
26-
case New(clazz: ENode)
26+
case Constructor(constr: Symbol)(valfields:List[Symbol])
2727
case Select(qual: ENode, member: Symbol)
2828
case Apply(fn: ENode, args: List[ENode])
2929
case OpApply(fn: ENode.Op, args: List[ENode])
@@ -33,7 +33,7 @@ enum ENode:
3333
override def toString(): String =
3434
this match
3535
case Atom(tp) => typeToString(tp)
36-
case New(clazz) => s"new $clazz"
36+
case Constructor(constr) => s"new ${designatorToString(constr.lastKnownDenotation.owner)}"
3737
case Select(qual, member) => s"$qual.${designatorToString(member)}"
3838
case Apply(fn, args) => s"$fn(${args.mkString(", ")})"
3939
case OpApply(op, args) => s"(${args.mkString(op.operatorString())})"

‎compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ class QualifierSolver(using Context):
5353
case _ => ()
5454

5555
val egraph = EGraph(ctx)
56-
//println(s"tree implies $tree1 -> $tree2")
56+
//println(s"tree implies $tree1 -> $tree2")
5757
(egraph.toNode(QualifierEvaluator.evaluate(tree1)), egraph.toNode(QualifierEvaluator.evaluate(tree2))) match
5858
case (Some(node1), Some(node2)) =>
59-
//println(s"node implies $node1 -> $node2")
59+
//println(s"node implies $node1 -> $node2")
6060
egraph.merge(node1, egraph.trueNode)
6161
egraph.repair()
6262
egraph.equiv(node2, egraph.trueNode)

‎tests/neg-custom-args/qualified-types/adapt_neg.scala‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ def test: Unit =
1212
val v3: {v: Int with v == x + 1} = x + 2 // error
1313
val v4: {v: Int with v == f(x)} = g(x) // error
1414
val v5: {v: Int with v == g(x)} = f(x) // error
15-
//val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented
16-
//val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented
15+
val v6: {v: IntBox with v == IntBox(x)} = IntBox(x+1) // error
16+
val v7: {v: Box[Int] with v == Box(x)} = Box(x+1) // error
1717
val v8: {v: Int with v == x + f(x)} = x + g(x) // error
1818
val v9: {v: Int with v == x + g(x)} = x + f(x) // error
1919
val v10: {v: Int with v == f(x + 1)} = f(x + 2) // error
2020
val v11: {v: Int with v == g(x + 1)} = g(x + 2) // error
21-
//val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented
22-
//val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented
21+
val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x) // error
22+
val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x) // error
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
class Box[T](val x: T)
2+
3+
class BoxMutable[T](var x: T)
4+
5+
class Foo(val id: String):
6+
def this(x: Int) = this(x.toString)
7+
8+
class Person(val name: String, val age: Int)
9+
10+
class PersonCurried(val name: String)(val age: Int)
11+
12+
class PersonMutable(val name: String, var age: Int)
13+
14+
case class PersonCaseMutable(name: String, var age: Int)
15+
16+
case class PersonCaseSecondary(name: String, age: Int):
17+
def this(name: String) = this(name, 0)
18+
19+
case class PersonCaseEqualsOverriden(name: String, age: Int):
20+
override def equals(that: Object): Boolean = this eq that
21+
22+
def test: Unit =
23+
summon[{b: Box[Int] with b == Box(1)} =:= {b: Box[Int] with b == Box(1)}] // error
24+
25+
summon[{b: BoxMutable[Int] with b == BoxMutable(1)} =:= {b: BoxMutable[Int] with b == BoxMutable(1)}] // error
26+
// TODO(mbovel): restrict selection to stable members
27+
//summon[{b: BoxMutable[Int] with b.x == 3} =:= {b: BoxMutable[Int] with b.x == 3}]
28+
29+
summon[{f: Foo with f == Foo("hello")} =:= {f: Foo with f == Foo("hello")}] // error
30+
summon[{f: Foo with f == Foo(1)} =:= {f: Foo with f == Foo(1)}] // error
31+
summon[{s: String with Foo("hello").id == s} =:= {s: String with s == "hello"}] // error
32+
33+
summon[{p: Person with p == Person("Alice", 30)} =:= {p: Person with p == Person("Alice", 30)}] // error
34+
summon[{s: String with Person("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
35+
summon[{n: Int with Person("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
36+
37+
summon[{p: PersonCurried with p == PersonCurried("Alice")(30)} =:= {p: PersonCurried with p == PersonCurried("Alice")(30)}] // error
38+
summon[{s: String with PersonCurried("Alice")(30).name == s} =:= {s: String with s == "Alice"}] // error
39+
summon[{n: Int with PersonCurried("Alice")(30).age == n} =:= {n: Int with n == 30}] // error
40+
41+
summon[{p: PersonMutable with p == PersonMutable("Alice", 30)} =:= {p: PersonMutable with p == PersonMutable("Alice", 30)}] // error
42+
summon[{s: String with PersonMutable("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
43+
summon[{n: Int with PersonMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
44+
45+
summon[{n: Int with PersonCaseMutable("Alice", 30).age == n} =:= {n: Int with n == 30}] // error
46+
47+
summon[{s: String with new PersonCaseSecondary("Alice").name == s} =:= {s: String with s == "Alice"}] // error
48+
summon[{n: Int with new PersonCaseSecondary("Alice").age == n} =:= {n: Int with n == 0}] // error
49+
50+
summon[{p: PersonCaseEqualsOverriden with PersonCaseEqualsOverriden("Alice", 30) == p} =:= {p: PersonCaseEqualsOverriden with p == PersonCaseEqualsOverriden("Alice", 30)}] // error
51+
summon[{s: String with PersonCaseEqualsOverriden("Alice", 30).name == s} =:= {s: String with s == "Alice"}] // error
52+
summon[{n: Int with PersonCaseEqualsOverriden("Alice", 30).age == n} =:= {n: Int with n == 30}] // error

‎tests/pos-custom-args/qualified-types/adapt.scala‎

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ def f(x: Int): Int = ???
22
case class IntBox(x: Int)
33
case class Box[T](x: T)
44

5-
65
def f(x: Int, y: Int): {r: Int with r == x + y} = x + y
76

87
def test: Unit =
@@ -12,11 +11,11 @@ def test: Unit =
1211
val v1: {v: Int with v == x + 1} = x + 1
1312
val v2: {v: Int with v == f(x)} = f(x)
1413
val v3: {v: Int with v == g(x)} = g(x)
15-
//val v6: {v: Int with v == IntBox(x)} = IntBox(x) // Not implemented
16-
//val v7: {v: Int with v == Box(x)} = Box(x) // Not implemented
17-
val v4: {v: Int with v == x + f(x)} = x + f(x)
18-
val v5: {v: Int with v == x + g(x)} = x + g(x)
19-
val v6: {v: Int with v == f(x + 1)} = f(x + 1)
20-
val v7: {v: Int with v == g(x + 1)} = g(x + 1)
21-
//val v12: {v: Int with v == IntBox(x + 1)} = IntBox(x + 1) // Not implemented
22-
//val v13: {v: Int with v == Box(x + 1)} = Box(x + 1) // Not implemented
14+
val v4: {v: IntBox with v == IntBox(x)} = IntBox(x)
15+
val v5: {v: Box[Int] with v == Box(x)} = Box(x)
16+
val v6: {v: Int with v == x + f(x)} = x + f(x)
17+
val v7: {v: Int with v == x + g(x)} = x + g(x)
18+
val v8: {v: Int with v == f(x + 1)} = f(x + 1)
19+
val v9: {v: Int with v == g(x + 1)} = g(x + 1)
20+
val v12: {v: IntBox with v == IntBox(x + 1)} = IntBox(x + 1)
21+
val v13: {v: Box[Int] with v == Box(x + 1)} = Box(x + 1)

‎tests/pos-custom-args/qualified-types/sized_lists.scala‎

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
2-
31
def size(v: Vec): Int = ???
42
type Vec
53

6-
74
def vec(s: Int): {v: Vec with size(v) == s} = ???
85
def concat(v1: Vec, v2: Vec): {v: Vec with size(v) == size(v1) + size(v2)} = ???
96
def sum(v1: Vec, v2: Vec with size(v1) == size(v2)): {v: Vec with size(v) == size(v1)} = ???

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /