Skip to content

Commit c8b6b32

Browse files
authored
Declare that definition_validator does not mutate (#9731)
* Declare that definition_validator does not mutate
* Need to change the Transform methods
* Rename `tree` → `topTree`
1 parent 01f9ff7 commit c8b6b32

File tree

7 files changed

+58
-22
lines changed

7 files changed

+58
-22
lines changed

ast/treemap/treemap.h

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -340,6 +340,29 @@ template <class FUNC, class CTX, TreeMapKind Kind, TreeMapDepthKind DepthKind> c
340340
CALL_POST(ClassDef);
341341
}
342342

343+
template <TreeMapKind K = Kind>
344+
typename std::enable_if_t<K == TreeMapKind::ConstWalk, return_type> mapClassDefDirect(const ClassDef &v, CTX ctx) {
345+
Funcs::template CALL_MEMBER_preTransformClassDef<FUNC>::call(func, ctx, v);
346+
347+
// We intentionally do not walk v->ancestors nor v->singletonAncestors.
348+
//
349+
// These lists used to be guaranteed to be simple trees (only constant literals) by desugar,
350+
// but that was later relaxed. In places where walking ancestors is required, instead define
351+
// your `preTransformClassDef` method to contain this:
352+
//
353+
// for (auto &ancestor : klass.ancestors) {
354+
// ast::ConstTreeWalk::apply(ctx, *this, ancestor); // const walk: visit only, nothing to reassign
355+
// }
356+
//
357+
// and that will have the same effect, without having to retroactively change all TreeMaps.
358+
359+
for (auto &def : v.rhs) {
360+
CALL_MAP(def, ctx.withOwner(v.symbol).withFile(ctx.file));
361+
}
362+
363+
Funcs::template CALL_MEMBER_postTransformClassDef<FUNC>::call(func, ctx, v);
364+
}
365+
343366
return_type mapMethodDef(arg_type v, CTX ctx) {
344367
CALL_PRE(MethodDef);
345368

@@ -756,6 +779,18 @@ class ConstTreeWalk {
756779
throw exception.reported;
757780
}
758781
}
782+
template <typename CTX, typename FUNC> static void apply(CTX ctx, FUNC &func, const ClassDef &to) {
783+
TreeMapper<FUNC, CTX, TreeMapKind::ConstWalk, TreeMapDepthKind::Full> walker(func);
784+
try {
785+
walker.mapClassDefDirect(to, ctx);
786+
} catch (ReportedRubyException &exception) {
787+
Exception::failInFuzzer();
788+
if (auto e = ctx.beginError(exception.onLoc, core::errors::Internal::InternalError)) {
789+
e.setHeader("Failed to process tree (backtrace is above)");
790+
}
791+
throw exception.reported;
792+
}
793+
}
759794
};
760795

761796
class ShallowMap {

core/sig_finder/sig_finder.cc

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,5 +133,10 @@ optional<SigFinder::Result> SigFinder::findSignature(core::Context ctx, const as
133133
ast::ConstTreeWalk::apply(ctx, sigFinder, tree);
134134
return move(sigFinder.result_);
135135
}
136+
optional<SigFinder::Result> SigFinder::findSignature(core::Context ctx, const ast::ClassDef &tree, core::Loc queryLoc) {
137+
SigFinder sigFinder(queryLoc);
138+
ast::ConstTreeWalk::apply(ctx, sigFinder, tree);
139+
return move(sigFinder.result_);
140+
}
136141

137142
} // namespace sorbet::sig_finder

core/sig_finder/sig_finder.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ class SigFinder {
3939
void preTransformSend(core::Context ctx, const ast::Send &tree);
4040

4141
static std::optional<Result> findSignature(core::Context ctx, const ast::ExpressionPtr &tree, core::Loc queryLoc);
42+
static std::optional<Result> findSignature(core::Context ctx, const ast::ClassDef &tree, core::Loc queryLoc);
4243
};
4344

4445
} // namespace sorbet::sig_finder

definition_validator/validator.cc

Lines changed: 14 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -773,7 +773,7 @@ void validateFinalAncestorHelper(core::Context ctx, const core::ClassOrModuleRef
773773
}
774774
}
775775

776-
void validateFinalMethodHelper(core::Context ctx, const core::ClassOrModuleRef klass, const ast::ExpressionPtr &tree,
776+
void validateFinalMethodHelper(core::Context ctx, const core::ClassOrModuleRef klass, const ast::ClassDef &classDef,
777777
const core::ClassOrModuleRef errMsgClass) {
778778
if (!klass.data(ctx)->flags.isFinal) {
779779
return;
@@ -795,7 +795,7 @@ void validateFinalMethodHelper(core::Context ctx, const core::ClassOrModuleRef k
795795
e.setHeader("`{}` was declared as final but its method `{}` was not declared as final",
796796
errMsgClass.show(ctx), sym.name(ctx).show(ctx));
797797
auto queryLoc = defLoc.copyWithZeroLength();
798-
auto parsedSig = sig_finder::SigFinder::findSignature(ctx, tree, queryLoc);
798+
auto parsedSig = sig_finder::SigFinder::findSignature(ctx, classDef, queryLoc);
799799

800800
if (parsedSig.has_value() && parsedSig->origSend.funLoc.exists()) {
801801
auto funLoc = ctx.locAt(parsedSig->origSend.funLoc);
@@ -805,8 +805,7 @@ void validateFinalMethodHelper(core::Context ctx, const core::ClassOrModuleRef k
805805
}
806806
}
807807

808-
void validateFinal(core::Context ctx, const core::ClassOrModuleRef klass, const ast::ExpressionPtr &tree) {
809-
const ast::ClassDef &classDef = ast::cast_tree_nonnull<ast::ClassDef>(tree);
808+
void validateFinal(core::Context ctx, const core::ClassOrModuleRef klass, const ast::ClassDef &classDef) {
810809
const auto superClass = klass.data(ctx)->superClass();
811810
if (superClass.exists()) {
812811
auto superClassData = superClass.data(ctx);
@@ -831,10 +830,10 @@ void validateFinal(core::Context ctx, const core::ClassOrModuleRef klass, const
831830
}
832831
}
833832
validateFinalAncestorHelper(ctx, klass, classDef, klass, "included");
834-
validateFinalMethodHelper(ctx, klass, tree, klass);
833+
validateFinalMethodHelper(ctx, klass, classDef, klass);
835834
const auto singleton = klass.data(ctx)->lookupSingletonClass(ctx);
836835
validateFinalAncestorHelper(ctx, singleton, classDef, klass, "extended");
837-
validateFinalMethodHelper(ctx, singleton, tree, klass);
836+
validateFinalMethodHelper(ctx, singleton, classDef, klass);
838837
}
839838

840839
// Ignore RBI files for the purpose of checking sealed (unless there are no other files).
@@ -1072,7 +1071,7 @@ void validateRequiredAncestors(core::Context ctx, const core::ClassOrModuleRef s
10721071

10731072
class ValidateWalk {
10741073
public:
1075-
ValidateWalk(const ast::ExpressionPtr &tree) : tree(tree) {}
1074+
ValidateWalk(const ast::ExpressionPtr &topTree) : topTree(topTree) {}
10761075

10771076
private:
10781077
// NOTE: A better representation for our AST might be to store a method's signature(s) within
@@ -1088,7 +1087,7 @@ class ValidateWalk {
10881087
// find the signature for a given method (incurring a full walk of the tree) every time we want
10891088
// to find a single method. That means instead of ValidateWalk only doing one walk of the tree,
10901089
// it does `num_errors + 1` walks, which can be slow in the case of many errors.
1091-
const ast::ExpressionPtr &tree;
1090+
const ast::ExpressionPtr &topTree;
10921091

10931092
UnorderedMap<core::ClassOrModuleRef, vector<core::MethodRef>> abstractCache;
10941093

@@ -1304,8 +1303,7 @@ class ValidateWalk {
13041303
}
13051304

13061305
public:
1307-
void preTransformClassDef(core::Context ctx, const ast::ExpressionPtr &tree) {
1308-
auto &classDef = ast::cast_tree_nonnull<ast::ClassDef>(tree);
1306+
void preTransformClassDef(core::Context ctx, const ast::ClassDef &classDef) {
13091307
auto sym = classDef.symbol;
13101308
auto singleton = sym.data(ctx)->lookupSingletonClass(ctx);
13111309
validateTStructNotGrandparent(ctx, sym);
@@ -1319,7 +1317,7 @@ class ValidateWalk {
13191317
}
13201318
}
13211319
validateAbstract(ctx, singleton, classDef);
1322-
validateFinal(ctx, sym, tree);
1320+
validateFinal(ctx, sym, classDef);
13231321
validateSealed(ctx, sym, classDef);
13241322
validateSuperClass(ctx, sym, classDef);
13251323

@@ -1328,8 +1326,7 @@ class ValidateWalk {
13281326
}
13291327
}
13301328

1331-
void preTransformMethodDef(core::Context ctx, const ast::ExpressionPtr &tree) {
1332-
auto &methodDef = ast::cast_tree_nonnull<ast::MethodDef>(tree);
1329+
void preTransformMethodDef(core::Context ctx, const ast::MethodDef &methodDef) {
13331330
auto methodData = methodDef.symbol.data(ctx);
13341331

13351332
if (methodData->locs().empty()) {
@@ -1353,11 +1350,10 @@ class ValidateWalk {
13531350
// See the comment in `VarianceValidator::validateMethod` for an explanation of why we don't
13541351
// need to check types on instance variables.
13551352

1356-
validateOverriding(ctx, this->tree, methodDef);
1353+
validateOverriding(ctx, this->topTree, methodDef);
13571354
}
13581355

1359-
void postTransformSend(core::Context ctx, const ast::ExpressionPtr &tree) {
1360-
auto &send = ast::cast_tree_nonnull<ast::Send>(tree);
1356+
void postTransformSend(core::Context ctx, const ast::Send &send) {
13611357
if (send.fun != core::Names::new_()) {
13621358
return;
13631359
}
@@ -1391,12 +1387,11 @@ class ValidateWalk {
13911387
};
13921388
} // namespace
13931389

1394-
ast::ParsedFile runOne(core::Context ctx, ast::ParsedFile tree) {
1390+
void runOne(core::Context ctx, const ast::ParsedFile &tree) {
13951391
Timer timeit(ctx.state.tracer(), "validateSymbols", {{"file", string(tree.file.data(ctx).path())}});
13961392

13971393
ValidateWalk validate(tree.tree);
1398-
ast::TreeWalk::apply(ctx, validate, tree.tree);
1399-
return tree;
1394+
ast::ConstTreeWalk::apply(ctx, validate, tree.tree);
14001395
}
14011396

14021397
} // namespace sorbet::definition_validator

definition_validator/validator.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
namespace sorbet::definition_validator {
88

9-
ast::ParsedFile runOne(core::Context ctx, ast::ParsedFile tree);
9+
void runOne(core::Context ctx, const ast::ParsedFile &tree);
1010

1111
} // namespace sorbet::definition_validator
1212

main/pipeline/pipeline.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1464,7 +1464,7 @@ void typecheckOne(core::Context ctx, ast::ParsedFile resolved, const options::Op
14641464

14651465
resolved = class_flatten::runOne(ctx, move(resolved));
14661466

1467-
resolved = definition_validator::runOne(ctx, std::move(resolved));
1467+
definition_validator::runOne(ctx, resolved);
14681468

14691469
if (opts.print.FlattenTree.enabled || opts.print.AST.enabled) {
14701470
opts.print.FlattenTree.fmt("{}\n", resolved.tree.toString(ctx));

test/pipeline_test_runner.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -560,7 +560,7 @@ TEST_CASE("PerPhaseTest") { // NOLINT
560560
core::Context ctx(*gs, core::Symbols::root(), file);
561561
resolvedTree = class_flatten::runOne(ctx, move(resolvedTree));
562562

563-
resolvedTree = definition_validator::runOne(ctx, move(resolvedTree));
563+
definition_validator::runOne(ctx, resolvedTree);
564564
handler.drainErrors(*gs);
565565

566566
handler.addObserved(*gs, "flatten-tree", [&]() { return resolvedTree.tree.toString(*gs); });

0 commit comments

Comments
 (0)