00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
#include <stdint.h>
#include <stdlib.h>
#include <iostream>
#include <string>
#include <sstream>
#include <vector>
#include <libsherpa/UExcept.hxx>
#include <libsherpa/CVector.hxx>
#include <assert.h>
#include "UocInfo.hxx"
#include "Options.hxx"
#include "AST.hxx"
#include "Type.hxx"
#include "TypeScheme.hxx"
#include "Typeclass.hxx"
#include "inter-pass.hxx"
#include "TypeEqInfer.hxx"
#include "TypeInferCommon.hxx"
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072 bool
00073 TransClose(GCPtr<Constraints> cset)
00074 {
00075 size_t start = cset->pred->size();
00076 size_t pass_start=0;
00077 do {
00078 pass_start = cset->pred->size();
00079
00080 for(size_t i=0; i < cset->size(); i++) {
00081 GCPtr<Constraint> cti = cset->Pred(i)->getType();
00082 if(cti->kind != ty_subtype)
00083 continue;
00084
00085 for(size_t j=i+1; j < cset->size(); j++) {
00086 GCPtr<Constraint> ctj = cset->Pred(j)->getType();
00087 if(ctj->kind != ty_subtype)
00088 continue;
00089
00090
00091 if(cti->CompType(1) == ctj->CompType(0))
00092 addSubCst(cti->ast, cti->CompType(0),
00093 ctj->CompType(1), cset);
00094
00095
00096 if(cti->CompType(0) == ctj->CompType(1))
00097 addSubCst(ctj->ast, ctj->CompType(0),
00098 cti->CompType(1), cset);
00099 }
00100 }
00101 } while(cset->pred->size() > pass_start);
00102
00103 size_t norm = cset->pred->size();
00104 cset->normalize();
00105
00106 size_t final = cset->pred->size();
00107 return ((start != norm) || (norm != final));
00108 }
00109
00110
00111
00112 extern GCPtr<TvPrinter> debugTvp;
00113 void
00114 printCset(std::ostream& out, GCPtr<Constraints> cset)
00115 {
00116 if(cset->size()) {
00117 out << "{";
00118 for(size_t i=0; i < cset->size(); i++) {
00119 if(i > 0)
00120 out << ", ";
00121 out << cset->Pred(i)->asString(debugTvp);
00122 }
00123 out << "}";
00124 }
00125 }
00126
00127
00128 static bool
00129 typeError(std::ostream& errStream, GCPtr<Constraint> ct)
00130 {
00131 errStream << ct->ast->loc << ": Type Error."
00132 << " Unsatiafiable constraint: "
00133 << ct->asString();
00134
00135
00136 return false;
00137 }
00138
00139
/// CMPSET(var, val): "sticky" assignment.  Once var is true it is left
/// alone; otherwise var is assigned val (for bool operands this is
/// `var = var || val`, with val evaluated only when var is false).
/// The do { } while(0) deliberately has NO trailing semicolon, so that
/// `CMPSET(a, b);` expands to exactly one statement and is safe inside an
/// unbraced if/else.  Arguments are parenthesized against precedence
/// surprises.
#define CMPSET(var, val)          \
  do {                            \
    if((var) != true)             \
      (var) = (val);              \
  } while(0)
00145
00146
00147 bool
00148 EqUnify(std::ostream& errStream, GCPtr<Constraints> cset,
00149 GCPtr<Trail> trail)
00150 {
00151 bool cset_changed = false;
00152 bool errFree = true;
00153
00154 TransClose(cset);
00155 do {
00156 cset_changed = false;
00157
00158 for(size_t i=0; i < cset->size(); i++) {
00159 bool unified_in_this_iteration = true;
00160
00161
00162
00163
00164 GCPtr<Constraint> ct = cset->Pred(i)->getType();
00165 if(ct->flags & CT_REMOVE)
00166 continue;
00167
00168 switch(ct->kind) {
00169 case ty_subtype:
00170 {
00171 GCPtr<Type> lhs = ct->CompType(0)->getType();
00172 GCPtr<Type> rhs = ct->CompType(1)->getType();
00173
00174
00175 if(lhs == rhs) {
00176 ct->flags |= CT_REMOVE;
00177 break;
00178 }
00179
00180 if((lhs->kind == rhs->kind) && lhs->isBaseConstType()) {
00181 ct->flags |= CT_REMOVE;
00182 break;
00183 }
00184
00185
00186 if((lhs->kind == ty_mutable) &&
00187 (lhs->CompType(0)->getType() == rhs)) {
00188 ct->flags |= CT_REMOVE;
00189 break;
00190 }
00191
00192
00193 if(lhs->kind == ty_tvar) {
00194 GCPtr<Type> reverse = new Constraint(ty_subtype, ct->ast,
00195 rhs, lhs);
00196 if(cset->contains(reverse)) {
00197 ct->flags |= CT_REMOVE;
00198 trail->subst(lhs, rhs);
00199 break;
00200 }
00201 }
00202
00203
00204
00205
00206
00207
00208 if((lhs->kind == ty_tvar) && rhs->isMaxMutable()) {
00209 ct->flags |= CT_REMOVE;
00210 trail->subst(lhs, rhs);
00211 break;
00212 }
00213
00214
00215 if((rhs->kind == ty_tvar) && (lhs->kind != ty_tvar) &&
00216 lhs->isMinMutable()) {
00217 ct->flags |= CT_REMOVE;
00218 trail->subst(rhs, lhs);
00219 break;
00220 }
00221
00222
00223 if((lhs->kind == ty_fn) && (rhs->kind == ty_fn)) {
00224 GCPtr<Type> arg1 = lhs->CompType(0)->getType();
00225 GCPtr<Type> arg2 = rhs->CompType(0)->getType();
00226 GCPtr<Type> ret1 = lhs->CompType(1)->getType();
00227 GCPtr<Type> ret2 = rhs->CompType(1)->getType();
00228
00229 if(arg1->components->size() == arg2->components->size()) {
00230 ct->flags |= CT_REMOVE;
00231
00232 for(size_t a=0; a < arg1->components->size(); a++)
00233 addEqCst(ct->ast, arg1->CompType(a),
00234 arg2->CompType(a), cset);
00235
00236 addEqCst(ct->ast, ret1, ret2, cset);
00237 break;
00238 }
00239 }
00240
00241
00242 if((lhs->kind == ty_ref) && (rhs->kind == ty_ref)) {
00243 ct->flags |= CT_REMOVE;
00244
00245 addEqCst(ct->ast, lhs->CompType(0),
00246 rhs->CompType(0), cset);
00247 break;
00248 }
00249
00250
00251 if((lhs->kind == ty_mutable) && (rhs->kind == ty_mutable)) {
00252 ct->flags |= CT_REMOVE;
00253
00254 addSubCst(ct->ast, lhs->CompType(0),
00255 rhs->CompType(0), cset);
00256 break;
00257 }
00258
00259
00260 if((lhs->kind == ty_mutable) && (rhs->kind != ty_tvar)) {
00261 assert(rhs->kind != ty_mutable);
00262 ct->flags |= CT_REMOVE;
00263
00264 addSubCst(ct->ast, lhs->CompType(0), rhs, cset);
00265 break;
00266 }
00267
00268
00269
00270
00271
00272 unified_in_this_iteration = false;
00273 break;
00274 }
00275
00276 case ty_pcst:
00277 {
00278 GCPtr<Type> k = ct->CompType(0)->getType();
00279 GCPtr<Type> gen = ct->CompType(1)->getType();
00280 GCPtr<Type> ins = ct->CompType(2)->getType();
00281
00282
00283 if(k == Type::Kmono) {
00284 ct->flags |= CT_REMOVE;
00285 addEqCst(ct->ast, gen, ins, cset);
00286 break;
00287 }
00288
00289
00290 if (k == Type::Kpoly && ins->isDeepImmutable()) {
00291 ct->flags |= CT_REMOVE;
00292 break;
00293 }
00294
00295 if(k->kind == ty_kvar && ins->isDeepMut()) {
00296 trail->subst(k, Type::Kmono);
00297 break;
00298 }
00299
00300
00301 if (k->kind == ty_kvar) {
00302 GCPtr<Constraints> newC = new Constraints();
00303 bool more_found = false;
00304 for(size_t c=0; c < cset->size(); c++) {
00305 GCPtr<Constraint> newCt = cset->Pred(c)->getType();
00306
00307 if(newCt->flags & CT_REMOVE)
00308 continue;
00309
00310 if(newCt->kind == ty_pcst && newCt->CompType(0) == k) {
00311 if(newCt != ct)
00312 more_found = true;
00313
00314 addEqCst(ct->ast, ct->CompType(2),
00315 newCt->CompType(2), newC);
00316 }
00317 else
00318 newC->addPred(newCt);
00319 }
00320
00321 if(more_found) {
00322 GCPtr<Trail> tr = new Trail();
00323 bool unifies = EqUnify(errStream, newC, tr);
00324 tr->rollBack();
00325
00326 if(!unifies) {
00327 trail->subst(k, Type::Kpoly);
00328 break;
00329 }
00330 }
00331 }
00332
00333 unified_in_this_iteration = false;
00334 break;
00335 }
00336
00337 default:
00338 {
00339 assert(false);
00340 break;
00341 }
00342 }
00343
00344 CMPSET(cset_changed, unified_in_this_iteration);
00345 }
00346
00347 if(cset_changed) {
00348 GCPtr<Constraints> newC = new Constraints();
00349 for(size_t c=0; c < cset->size(); c++) {
00350 GCPtr<Constraint> ct = cset->Pred(c)->getType();
00351
00352 if((ct->flags & CT_REMOVE) == 0)
00353 newC->addPred(ct);
00354 }
00355
00356 cset->pred = newC->pred;
00357
00358
00359
00360
00361
00362 TransClose(cset);
00363
00364
00365
00366
00367 }
00368
00369 } while(cset_changed);
00370
00371
00372 for(size_t i=0; i < cset->size(); i++) {
00373 GCPtr<Constraint> ct = cset->Pred(i)->getType();
00374
00375
00376 switch(ct->kind) {
00377 case ty_subtype:
00378 {
00379 GCPtr<Type> lhs = ct->CompType(0)->getType();
00380 GCPtr<Type> rhs = ct->CompType(1)->getType();
00381
00382
00383 if(lhs->kind == ty_tvar || rhs->kind == ty_tvar)
00384 break;
00385
00386 errFree = typeError(errStream, ct);
00387 break;
00388 }
00389 case ty_pcst:
00390 {
00391 GCPtr<Type> k = ct->CompType(0)->getType();
00392 GCPtr<Type> gen = ct->CompType(1)->getType();
00393 GCPtr<Type> ins = ct->CompType(2)->getType();
00394
00395
00396 if (k->kind == ty_kvar)
00397 break;
00398
00399
00400 if (k == Type::Kpoly && ins->isDeepImmut())
00401 break;
00402
00403 errFree = typeError(errStream, ct);
00404 break;
00405 }
00406
00407 default:
00408 {
00409 assert(false);
00410 break;
00411 }
00412 }
00413 }
00414
00415 return errFree;
00416 }