LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
// Statistic counters for this analysis (see llvm/ADT/Statistic.h).
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
// Global SCEV verification switch. It defaults to true only when
// EXPENSIVE_CHECKS is defined; the -verify-scev option binds to it through
// cl::location(VerifySCEV).
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
// Command-line knobs for ScalarEvolution. The option name and cl::desc text
// of each entry state its purpose; cl::init gives its default value.
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
// NOTE(review): the description above reads "with -verify-scev is passed";
// "when -verify-scev is passed" was probably intended. Changing it alters the
// user-visible --help text, so it is left as-is here.
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
// Whether printed analysis output covers every instruction; on by default.
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
// Handling of <= and >= comparisons in finite loops; on by default.
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
// NOTE(review): "strenghening" in the option name below is misspelled. The
// name is user-visible on the command line, so renaming it would break
// existing invocations.
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
// Compute this node's canonical form and store it in CanonicalSCEV.
// Constants, vscale and unknowns are leaves and are their own canonical.
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
// Every operand is already canonical, so this node is canonical too.
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
// Otherwise, rebuild the node from its canonical operands through the
// matching ScalarEvolution factory. N-ary nodes carry over their no-wrap
// flags; any other kind gets FlagAnyWrap. Flags is only passed for add, mul
// and addrec.
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
// Sequential umin is rebuilt through getUMinExpr with Sequential=true.
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
// Debug helper: prints this SCEV to dbgs() followed by a newline. It is only
// compiled when NDEBUG is not defined or LLVM_ENABLE_DUMP is defined.
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
/// Print a human-readable form of this SCEV to \p OS.
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
// Both pointer casts print as "(ptrtoaddr ...)" or "(ptrtoint ...)".
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
// Recurrences print as {Start,+,Step,...}<flags><LoopHeader>. <nw> appears
// only when neither <nuw> nor <nsw> is set.
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
// N-ary expressions print their operands in parentheses, separated by the
// operator spelling chosen in OpStr. Add and mul then print <nuw>/<nsw>.
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
433 << ")";
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
// Return the LLVM type of this SCEV by forwarding to the concrete node
// class selected by its kind. Asking a SCEVCouldNotCompute for its type is
// a programming error (llvm_unreachable).
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
// Return this node's operands. Leaves (constant, vscale, unknown) have none;
// every other kind forwards to its concrete node class. Asking a
// SCEVCouldNotCompute for its operands is a programming error.
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
// Only multiplies can qualify; bail out when this SCEV is not one.
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
// S is a SCEVCouldNotCompute exactly when its kind is scCouldNotCompute.
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
// Uniquing key for a constant node: its kind plus the ConstantInt pointer.
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
// Reuse an existing node with the same key; otherwise IP receives the
// insertion position for the new node.
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
// Every newly created node gets its canonical form computed immediately.
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
// Wrap the APInt in a ConstantInt of this context and defer to the
// ConstantInt overload.
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
// Builds the ConstantInt with ImplicitTrunc=true, so a value of V that does
// not fit in ITy is truncated rather than rejected.
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
// Uniquing key for a vscale node: its kind plus the result type.
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
// Reuse an existing node when one matches; otherwise create and insert one.
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
// The known-minimum element count; scalable counts are further multiplied
// by vscale, using the caller-supplied no-wrap flags.
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
// Truncation node. The assertion requires both the operand type and the
// result type to be integer or pointer types.
// NOTE(review): the message says "non-integer", but pointer types also pass.
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
// Zero-extension node. The assertion requires both the operand type and the
// result type to be integer or pointer types.
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
// Sign-extension node. The assertion requires both the operand type and the
// result type to be integer or pointer types.
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
// Detach this SCEVUnknown from ScalarEvolution and drop its value pointer.
// Contrast with allUsesReplacedWith, which repoints the node at the new value.
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
///
/// Returns a negative, zero, or positive value when LV orders before, equal
/// to, or after RV. Keys, in order: pointer-ness, value ID, argument number,
/// global linkage and name, instruction loop depth, operand count, and then
/// the operands themselves (compared recursively with Depth + 1).
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
// NOTE(review): early exit returning 0. Its guard is not part of this
// listing; presumably it fires for identical values or once Depth exceeds
// MaxValueCompareDepth. Confirm against the full source.
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
// Past this point both values have the same value ID, so the cast<> calls
// below on RV match the dyn_cast<> on LV.
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
// Ties on both are broken by comparing the operands pairwise.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
// NOTE(review): depth cutoff. Its guard is not part of this listing;
// presumably it compares Depth against MaxSCEVCompareDepth.
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
// Equal values are expected to be the same uniqued SCEV, so they should
// already have been caught by the LHS == RHS fast path above.
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
// NOTE(review): unlike the other cases, this subtracts unsigned bit widths
// without casting to int first. A narrower LTy yields a negative result only
// through the unsigned-to-int conversion on return.
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
// Recurrences over the same loop are compared operand-wise below.
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
// X is std::optional<int>. std::nullopt (depth limit hit) also compares
// != 0, so it is propagated to the caller unchanged.
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
// A std::nullopt comparison (depth limit hit) counts as "not less" in either
// direction.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
// Ops has at least three elements here, so e-2 cannot underflow.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
/// Here S is each operand visited by the any_of predicate below.
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
///
/// Otherwise returns nullptr. In that case \p Ops has been rewritten in place:
/// the non-constant operands are grouped by complexity, and any surviving
/// folded constant is placed first.
/// \p Ops must be non-empty; an empty input trips the "Must have folded
/// value" assertion.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
// Pull every constant operand out of Ops and fold them into one value.
// Idx only advances past non-constants, because erase shifts the rest down.
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
// All operands were constants: the folded value is the result.
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
// Keep the folded constant (in front) unless it is the identity element.
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for binomial coefficient.
1076 ScalarEvolution &SE) const {
1077 return evaluateAtIteration(operands(), It, SE);
1078}
1079
1081 const SCEV *It,
1082 ScalarEvolution &SE) {
1083 assert(Operands.size() > 0);
1084 const SCEV *Result = Operands[0].getPointer();
1085 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1086 // The computation is correct in the face of overflow provided that the
1087 // multiplication is performed _after_ the evaluation of the binomial
1088 // coefficient.
1089 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1090 if (isa<SCEVCouldNotCompute>(Coeff))
1091 return Coeff;
1092
1093 Result =
1094 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1095 }
1096 return Result;
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
1103/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1104/// which computes a pointer-typed value, and rewrites the whole expression
1105/// tree so that *all* the computations are done on integers, and the only
1106/// pointer-typed operands in the expression are SCEVUnknown.
1107/// The CreatePtrCast callback is invoked to create the actual conversion
1108/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
1110 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1112 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1113 Type *TargetTy;
1114 ConversionFn CreatePtrCast;
1115
1116public:
1118 ConversionFn CreatePtrCast)
1119 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1120
1121 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1122 Type *TargetTy, ConversionFn CreatePtrCast) {
1123 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1124 return Rewriter.visit(Scev);
1125 }
1126
1127 const SCEV *visit(const SCEV *S) {
1128 Type *STy = S->getType();
1129 // If the expression is not pointer-typed, just keep it as-is.
1130 if (!STy->isPointerTy())
1131 return S;
1132 // Else, recursively sink the cast down into it.
1133 return Base::visit(S);
1134 }
1135
1136 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1137 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1138 // implementation drops.
1139 SmallVector<SCEVUse, 2> Operands;
1140 bool Changed = false;
1141 for (SCEVUse Op : Expr->operands()) {
1142 Operands.push_back(visit(Op.getPointer()));
1143 Changed |= Op.getPointer() != Operands.back();
1144 }
1145 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1146 }
1147
1148 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1149 SmallVector<SCEVUse, 2> Operands;
1150 bool Changed = false;
1151 for (SCEVUse Op : Expr->operands()) {
1152 Operands.push_back(visit(Op.getPointer()));
1153 Changed |= Op.getPointer() != Operands.back();
1154 }
1155 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1156 }
1157
1158 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1159 assert(Expr->getType()->isPointerTy() &&
1160 "Should only reach pointer-typed SCEVUnknown's.");
1161 // Perform some basic constant folding. If the operand of the cast is a
1162 // null pointer, don't create a cast SCEV expression (that will be left
1163 // as-is), but produce a zero constant.
1165 return SE.getZero(TargetTy);
1166 return CreatePtrCast(Expr);
1167 }
1168};
1169
1171 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1172
1173 // It isn't legal for optimizations to construct new ptrtoint expressions
1174 // for non-integral pointers.
1175 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1176 return getCouldNotCompute();
1177
1178 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1179
1180 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1181 // is sufficiently wide to represent all possible pointer values.
1182 // We could theoretically teach SCEV to truncate wider pointers, but
1183 // that isn't implemented for now.
1186 return getCouldNotCompute();
1187
1188 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1190 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1192 ID.AddInteger(scPtrToInt);
1193 ID.AddPointer(U);
1194 void *IP = nullptr;
1195 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1196 return S;
1197 SCEV *S = new (SCEVAllocator)
1198 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1199 UniqueSCEVs.InsertNode(S, IP);
1200 S->computeAndSetCanonical(*this);
1201 registerUser(S, U);
1202 return static_cast<const SCEV *>(S);
1203 });
1204 assert(IntOp->getType()->isIntegerTy() &&
1205 "We must have succeeded in sinking the cast, "
1206 "and ending up with an integer-typed expression!");
1207 return IntOp;
1208}
1209
1211 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1212
1213 // Treat pointers with unstable representation conservatively, since the
1214 // address bits may change.
1215 if (DL.hasUnstableRepresentation(Op->getType()))
1216 return getCouldNotCompute();
1217
1218 Type *Ty = DL.getAddressType(Op->getType());
1219
1220 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1221 // The rewriter handles null pointer constant folding.
1223 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
1225 ID.AddInteger(scPtrToAddr);
1226 ID.AddPointer(U);
1227 ID.AddPointer(Ty);
1228 void *IP = nullptr;
1229 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1230 return S;
1231 SCEV *S = new (SCEVAllocator)
1232 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 S->computeAndSetCanonical(*this);
1235 registerUser(S, U);
1236 return static_cast<const SCEV *>(S);
1237 });
1238 assert(IntOp->getType()->isIntegerTy() &&
1239 "We must have succeeded in sinking the cast, "
1240 "and ending up with an integer-typed expression!");
1241 return IntOp;
1242}
1243
1245 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1246
1247 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1248 if (isa<SCEVCouldNotCompute>(IntOp))
1249 return IntOp;
1250
1251 return getTruncateOrZeroExtend(IntOp, Ty);
1252}
1253
1255 unsigned Depth) {
1256 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1257 "This is not a truncating conversion!");
1258 assert(isSCEVable(Ty) &&
1259 "This is not a conversion to a SCEVable type!");
1260 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1261 Ty = getEffectiveSCEVType(Ty);
1262
1264 ID.AddInteger(scTruncate);
1265 ID.AddPointer(Op);
1266 ID.AddPointer(Ty);
1267 void *IP = nullptr;
1268 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1269
1270 // Fold if the operand is constant.
1271 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1272 return getConstant(
1273 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1274
1275 // trunc(trunc(x)) --> trunc(x)
1277 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1278
1279 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1281 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1282
1283 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1285 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1286
1287 if (Depth > MaxCastDepth) {
1288 SCEV *S =
1289 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1290 UniqueSCEVs.InsertNode(S, IP);
1291 S->computeAndSetCanonical(*this);
1292 registerUser(S, Op);
1293 return S;
1294 }
1295
1296 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1297 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1298 // if after transforming we have at most one truncate, not counting truncates
1299 // that replace other casts.
1301 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1302 SmallVector<SCEVUse, 4> Operands;
1303 unsigned numTruncs = 0;
1304 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1305 ++i) {
1306 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1307 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1309 numTruncs++;
1310 Operands.push_back(S);
1311 }
1312 if (numTruncs < 2) {
1313 if (isa<SCEVAddExpr>(Op))
1314 return getAddExpr(Operands);
1315 if (isa<SCEVMulExpr>(Op))
1316 return getMulExpr(Operands);
1317 llvm_unreachable("Unexpected SCEV type for Op.");
1318 }
1319 // Although we checked in the beginning that ID is not in the cache, it is
1320 // possible that ID was inserted into the cache during the recursive calls
1321 // above or by other modifications. So if we find it, just return it.
1322 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1323 return S;
1324 }
1325
1326 // If the input value is a chrec scev, truncate the chrec's operands.
1327 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1328 SmallVector<SCEVUse, 4> Operands;
1329 for (const SCEV *Op : AddRec->operands())
1330 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1331 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1332 }
1333
1334 // Return zero if truncating to known zeros.
1335 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1336 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1337 return getZero(Ty);
1338
1339 // The cast wasn't folded; create an explicit cast node. We can reuse
1340 // the existing insert position since if we get here, we won't have
1341 // made any changes which would invalidate it.
1342 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1343 Op, Ty);
1344 UniqueSCEVs.InsertNode(S, IP);
1345 S->computeAndSetCanonical(*this);
1346 registerUser(S, Op);
1347 return S;
1348}
1349
1350// Get the limit of a recurrence such that incrementing by Step cannot cause
1351// signed overflow as long as the value of the recurrence within the
1352// loop does not exceed this limit before incrementing.
1353static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1354 ICmpInst::Predicate *Pred,
1355 ScalarEvolution *SE) {
1356 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1357 if (SE->isKnownPositive(Step)) {
1358 *Pred = ICmpInst::ICMP_SLT;
1360 SE->getSignedRangeMax(Step));
1361 }
1362 if (SE->isKnownNegative(Step)) {
1363 *Pred = ICmpInst::ICMP_SGT;
1365 SE->getSignedRangeMin(Step));
1366 }
1367 return nullptr;
1368}
1369
1370// Get the limit of a recurrence such that incrementing by Step cannot cause
1371// unsigned overflow as long as the value of the recurrence within the loop does
1372// not exceed this limit before incrementing.
1374 ICmpInst::Predicate *Pred,
1375 ScalarEvolution *SE) {
1376 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1377 *Pred = ICmpInst::ICMP_ULT;
1378
1380 SE->getUnsignedRangeMax(Step));
1381}
1382
1383namespace {
1384
1385struct ExtendOpTraitsBase {
1386 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1387 unsigned);
1388};
1389
1390// Used to make code generic over signed and unsigned overflow.
1391template <typename ExtendOp> struct ExtendOpTraits {
1392 // Members present:
1393 //
1394 // static const SCEV::NoWrapFlags WrapType;
1395 //
1396 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1397 //
1398 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1399 // ICmpInst::Predicate *Pred,
1400 // ScalarEvolution *SE);
1401};
1402
1403template <>
1404struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1405 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1406
1407 static const GetExtendExprTy GetExtendExpr;
1408
1409 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1410 ICmpInst::Predicate *Pred,
1411 ScalarEvolution *SE) {
1412 return getSignedOverflowLimitForStep(Step, Pred, SE);
1413 }
1414};
1415
1416const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1418
1419template <>
1420struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1422
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430};
1431
1432const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
1435} // end anonymous namespace
1436
1437// The recurrence AR has been shown to have no signed/unsigned wrap or something
1438// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1439// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1440// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1441// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1442// expression "Step + sext/zext(PreIncAR)" is congruent with
1443// "sext/zext(PostIncAR)"
1444template <typename ExtendOpTy>
1445static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1446 ScalarEvolution *SE, unsigned Depth) {
1447 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1448 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1449
1450 const Loop *L = AR->getLoop();
1451 const SCEV *Start = AR->getStart();
1452 const SCEV *Step = AR->getStepRecurrence(*SE);
1453
1454 // Check for a simple looking step prior to loop entry.
1455 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1456 if (!SA)
1457 return nullptr;
1458
1459 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1460 // subtraction is expensive. For this purpose, perform a quick and dirty
1461 // difference, by checking for Step in the operand list. Note, that
1462 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1463 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1464 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1465 if (*It == Step) {
1466 DiffOps.erase(It);
1467 break;
1468 }
1469
1470 if (DiffOps.size() == SA->getNumOperands())
1471 return nullptr;
1472
1473 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1474 // `Step`:
1475
1476 // 1. NSW/NUW flags on the step increment.
1477 auto PreStartFlags =
1479 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1481 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1482
1483 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1484 // "S+X does not sign/unsign-overflow".
1485 //
1486
1487 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1488 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
1489 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1490 return PreStart;
1491
1492 // 2. Direct overflow check on the step operation's expression.
1493 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1494 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1495 const SCEV *OperandExtendedStart =
1496 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1497 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1498 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1499 if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
1500 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1501 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1502 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1503 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1504 }
1505 return PreStart;
1506 }
1507
1508 // 3. Loop precondition.
1510 const SCEV *OverflowLimit =
1511 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1512
1513 if (OverflowLimit &&
1514 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1515 return PreStart;
1516
1517 return nullptr;
1518}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
1536// Try to prove away overflow by looking at "nearby" add recurrences. A
1537// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1538// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1539//
1540// Formally:
1541//
1542// {S,+,X} == {S-T,+,X} + T
1543// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1544//
1545// If ({S-T,+,X} + T) does not overflow ... (1)
1546//
1547// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1548//
1549// If {S-T,+,X} does not overflow ... (2)
1550//
1551// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1552// == {Ext(S-T)+Ext(T),+,Ext(X)}
1553//
1554// If (S-T)+T does not overflow ... (3)
1555//
1556// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1557// == {Ext(S),+,Ext(X)} == LHS
1558//
1559// Thus, if (1), (2) and (3) are true for some T, then
1560// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1561//
1562// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1563// does not overflow" restricted to the 0th iteration. Therefore we only need
1564// to check for (1) and (2).
1565//
1566// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1567// is `Delta` (defined below).
1568template <typename ExtendOpTy>
1569bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1570 const SCEV *Step,
1571 const Loop *L) {
1572 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1573
1574 // We restrict `Start` to a constant to prevent SCEV from spending too much
1575 // time here. It is correct (but more expensive) to continue with a
1576 // non-constant `Start` and do a general SCEV subtraction to compute
1577 // `PreStart` below.
1578 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1579 if (!StartC)
1580 return false;
1581
1582 APInt StartAI = StartC->getAPInt();
1583
1584 for (unsigned Delta : {-2, -1, 1, 2}) {
1585 const SCEV *PreStart = getConstant(StartAI - Delta);
1586
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scAddRecExpr);
1589 ID.AddPointer(PreStart);
1590 ID.AddPointer(Step);
1591 ID.AddPointer(L);
1592 void *IP = nullptr;
1593 const auto *PreAR =
1594 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1595
1596 // Give up if we don't already have the add recurrence we need because
1597 // actually constructing an add recurrence is relatively expensive.
1598 if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
1599 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1601 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1602 DeltaS, &Pred, this);
1603 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1604 return true;
1605 }
1606 }
1607
1608 return false;
1609}
1610
1611// Finds an integer D for an expression (C + x + y + ...) such that the top
1612// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1613// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1614// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1615// the (C + x + y + ...) expression is \p WholeAddExpr.
1617 const SCEVConstant *ConstantTerm,
1618 const SCEVAddExpr *WholeAddExpr) {
1619 const APInt &C = ConstantTerm->getAPInt();
1620 const unsigned BitWidth = C.getBitWidth();
1621 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1622 uint32_t TZ = BitWidth;
1623 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1624 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1625 if (TZ) {
1626 // Set D to be as many least significant bits of C as possible while still
1627 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1628 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1629 }
1630 return APInt(BitWidth, 0);
1631}
1632
1633// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1634// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1635// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1636// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1638 const APInt &ConstantStart,
1639 const SCEV *Step) {
1640 const unsigned BitWidth = ConstantStart.getBitWidth();
1641 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1642 if (TZ)
1643 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1644 : ConstantStart;
1645 return APInt(BitWidth, 0);
1646}
1647
1649 const ScalarEvolution::FoldID &ID, const SCEV *S,
1652 &FoldCacheUser) {
1653 auto I = FoldCache.insert({ID, S});
1654 if (!I.second) {
1655 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1656 // entry.
1657 auto &UserIDs = FoldCacheUser[I.first->second];
1658 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1659 for (unsigned I = 0; I != UserIDs.size(); ++I)
1660 if (UserIDs[I] == ID) {
1661 std::swap(UserIDs[I], UserIDs.back());
1662 break;
1663 }
1664 UserIDs.pop_back();
1665 I.first->second = S;
1666 }
1667 FoldCacheUser[S].push_back(ID);
1668}
1669
1670const SCEV *
1672 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1673 "This is not an extending conversion!");
1674 assert(isSCEVable(Ty) &&
1675 "This is not a conversion to a SCEVable type!");
1676 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1677 Ty = getEffectiveSCEVType(Ty);
1678
1679 FoldID ID(scZeroExtend, Op, Ty);
1680 if (const SCEV *S = FoldCache.lookup(ID))
1681 return S;
1682
1683 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1685 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1686 return S;
1687}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the latter case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two possible causes of an issue with a change that was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 const APInt *C, *C2;
1914 // zext (C + A)<nsw> -> (sext(C) + sext(A))<nsw> if zext (C + A)<nsw> >=s 0.
1915 // Currently the non-negative check is done manually, as isKnownNonNegative
1916 // is too expensive.
1917 if (SA->hasNoSignedWrap() &&
1919 m_scev_SMax(m_scev_APInt(C2), m_SCEV()))) &&
1920 C->isNegative() && !C->isMinSignedValue() && C2->sge(C->abs())) {
1921 assert(isKnownNonNegative(SA) && "incorrectly determined non-negative");
1922 return getAddExpr(getSignExtendExpr(SA->getOperand(0), Ty, Depth + 1),
1923 getSignExtendExpr(SA->getOperand(1), Ty, Depth + 1),
1924 SCEV::FlagNSW, Depth + 1);
1925 }
1926
1927 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1928 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1929 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1930 //
1931 // Often address arithmetics contain expressions like
1932 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1933 // This transformation is useful while proving that such expressions are
1934 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1935 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1936 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1937 if (D != 0) {
1938 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1939 const SCEV *SResidual =
1941 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1942 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1943 Depth + 1);
1944 }
1945 }
1946 }
1947
1948 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1949 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1950 if (SM->hasNoUnsignedWrap()) {
1951 // If the multiply does not unsign overflow then we can, by definition,
1952 // commute the zero extension with the multiply operation.
1954 for (SCEVUse Op : SM->operands())
1955 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1956 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1957 }
1958
1959 // zext(2^K * (trunc X to iN)) to iM ->
1960 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1961 //
1962 // Proof:
1963 //
1964 // zext(2^K * (trunc X to iN)) to iM
1965 // = zext((trunc X to iN) << K) to iM
1966 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1967 // (because shl removes the top K bits)
1968 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1969 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1970 //
1971 const APInt *C;
1972 const SCEV *TruncRHS;
1973 if (match(SM,
1974 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1975 C->isPowerOf2()) {
1976 int NewTruncBits =
1977 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1978 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1979 return getMulExpr(
1980 getZeroExtendExpr(SM->getOperand(0), Ty),
1981 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1982 SCEV::FlagNUW, Depth + 1);
1983 }
1984 }
1985
1986 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1987 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1990 SmallVector<SCEVUse, 4> Operands;
1991 for (SCEVUse Operand : MinMax->operands())
1992 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1994 return getUMinExpr(Operands);
1995 return getUMaxExpr(Operands);
1996 }
1997
1998 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
2000 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
2001 SmallVector<SCEVUse, 4> Operands;
2002 for (SCEVUse Operand : MinMax->operands())
2003 Operands.push_back(getZeroExtendExpr(Operand, Ty));
2004 return getUMinExpr(Operands, /*Sequential*/ true);
2005 }
2006
2007 // The cast wasn't folded; create an explicit cast node.
2008 // Recompute the insert position, as it may have been invalidated.
2009 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2010 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
2011 Op, Ty);
2012 UniqueSCEVs.InsertNode(S, IP);
2013 S->computeAndSetCanonical(*this);
2014 registerUser(S, Op);
2015 return S;
2016}
2017
// getSignExtendExpr - Return a SCEV for Op sign-extended to Ty.
// Asserts that Ty is strictly wider than Op's type, is SCEVable, and that Op
// is not a pointer; then normalizes Ty with getEffectiveSCEVType. A FoldCache
// entry keyed on (scSignExtend, Op, Ty) is returned directly if present;
// otherwise the work is delegated to getSignExtendExprImpl and the result is
// handed to insertFoldCacheEntry so later identical queries can hit the cache.
2018const SCEV *
2020 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2021 "This is not an extending conversion!");
2022 assert(isSCEVable(Ty) &&
2023 "This is not a conversion to a SCEVable type!");
2024 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2025 Ty = getEffectiveSCEVType(Ty);
2026
 // Fast path: reuse a previously computed fold for the same (Op, Ty).
2027 FoldID ID(scSignExtend, Op, Ty);
2028 if (const SCEV *S = FoldCache.lookup(ID))
2029 return S;
2030
2031 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2033 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2034 return S;
2035}
2036
// getSignExtendExprImpl - Core of getSignExtendExpr: fold sext(Op) to Ty into
// a simpler/canonical form where possible, otherwise create and unique a
// SCEVSignExtendExpr node. Folds are attempted in this order:
//   - constant operands;
//   - sext(sext(x)) and sext(zext(x));
//   - an already-uniqued node for (scSignExtend, Op, Ty);
//   - a recursion cutoff (Depth > MaxCastDepth builds the node directly);
//   - sext(trunc(x)) when the truncated bits are provably sign bits;
//   - add expressions (nsw distribution, constant splitting);
//   - affine add-recurrences proven not to signed-wrap;
//   - a zext fallback for provably positive operands;
//   - distribution over smin/smax.
2038 unsigned Depth) {
2039 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2040 "This is not an extending conversion!");
2041 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2042 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2043 Ty = getEffectiveSCEVType(Ty);
2044
2045 // Fold if the operand is constant.
2046 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2047 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2048
2049 // sext(sext(x)) --> sext(x)
2051 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2052
2053 // sext(zext(x)) --> zext(x)
2055 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2056
2057 // Before doing any expensive analysis, check to see if we've already
2058 // computed a SCEV for this Op and Ty.
2060 ID.AddInteger(scSignExtend);
2061 ID.AddPointer(Op);
2062 ID.AddPointer(Ty);
2063 void *IP = nullptr;
2064 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2065 // Limit recursion depth.
2066 if (Depth > MaxCastDepth) {
2067 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2068 Op, Ty);
2069 UniqueSCEVs.InsertNode(S, IP);
2070 S->computeAndSetCanonical(*this);
2071 registerUser(S, Op);
2072 return S;
2073 }
2074
2075 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2077 // It's possible the bits taken off by the truncate were all sign bits. If
2078 // so, we should be able to simplify this further.
2079 const SCEV *X = ST->getOperand();
2081 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2082 unsigned NewBits = getTypeSizeInBits(Ty);
2083 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2084 CR.sextOrTrunc(NewBits)))
2085 return getTruncateOrSignExtend(X, Ty, Depth);
2086 }
2087
2088 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2089 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2090 if (SA->hasNoSignedWrap()) {
2091 // If the addition does not sign overflow then we can, by definition,
2092 // commute the sign extension with the addition operation.
2094 for (SCEVUse Op : SA->operands())
2095 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2096 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2097 }
2098
2099 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2100 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2101 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2102 //
2103 // For instance, this will bring two seemingly different expressions:
2104 // 1 + sext(5 + 20 * %x + 24 * %y) and
2105 // sext(6 + 20 * %x + 24 * %y)
2106 // to the same form:
2107 // 2 + sext(4 + 20 * %x + 24 * %y)
2108 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2109 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2110 if (D != 0) {
2111 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2112 const SCEV *SResidual =
2114 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2115 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2116 Depth + 1);
2117 }
2118 }
2119 }
2120 // If the input value is a chrec scev, and we can prove that the value
2121 // did not overflow the old, smaller, value, we can sign extend all of the
2122 // operands (often constants). This allows analysis of something like
2123 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2125 if (AR->isAffine()) {
2126 const SCEV *Start = AR->getStart();
2127 const SCEV *Step = AR->getStepRecurrence(*this);
2128 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2129 const Loop *L = AR->getLoop();
2130
2131 // If we have special knowledge that this addrec won't overflow,
2132 // we don't need to do any further analysis.
2133 if (AR->hasNoSignedWrap()) {
2134 Start =
2136 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2137 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2138 }
2139
2140 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2141 // Note that this serves two purposes: It filters out loops that are
2142 // simply not analyzable, and it covers the case where this code is
2143 // being called from within backedge-taken count analysis, such that
2144 // attempting to ask for the backedge-taken count would likely result
2145 // in infinite recursion. In the latter case, the analysis code will
2146 // cope with a conservative value, and it will take care to purge
2147 // that value once it has finished.
2148 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2149 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2150 // Manually compute the final value for AR, checking for
2151 // overflow.
2152
2153 // Check whether the backedge-taken count can be losslessly casted to
2154 // the addrec's type. The count is always unsigned.
2155 const SCEV *CastedMaxBECount =
2156 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2157 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2158 CastedMaxBECount, MaxBECount->getType(), Depth);
2159 if (MaxBECount == RecastedMaxBECount) {
2160 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2161 // Check whether Start+Step*MaxBECount has no signed overflow.
2162 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2164 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2166 Depth + 1),
2167 WideTy, Depth + 1);
2168 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2169 const SCEV *WideMaxBECount =
2170 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2171 const SCEV *OperandExtendedAdd =
2172 getAddExpr(WideStart,
2173 getMulExpr(WideMaxBECount,
2174 getSignExtendExpr(Step, WideTy, Depth + 1),
2177 if (SAdd == OperandExtendedAdd) {
2178 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2179 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2180 // Return the expression with the addrec on the outside.
2181 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2182 Depth + 1);
2183 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2184 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2185 }
2186 // Similar to above, only this time treat the step value as unsigned.
2187 // This covers loops that count up with an unsigned step.
2188 OperandExtendedAdd =
2189 getAddExpr(WideStart,
2190 getMulExpr(WideMaxBECount,
2191 getZeroExtendExpr(Step, WideTy, Depth + 1),
2194 if (SAdd == OperandExtendedAdd) {
2195 // If AR wraps around then
2196 //
2197 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2198 // => SAdd != OperandExtendedAdd
2199 //
2200 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2201 // (SAdd == OperandExtendedAdd => AR is NW)
2202
2203 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2204
2205 // Return the expression with the addrec on the outside.
2206 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2207 Depth + 1);
2208 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2209 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2210 }
2211 }
2212 }
2213
 // Fall back to an induction-based proof; whatever flags it establishes
 // are cached on AR before re-checking for nsw.
2214 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2215 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2216 if (AR->hasNoSignedWrap()) {
2217 // Same as nsw case above - duplicated here to avoid a compile time
2218 // issue. It's not clear that the order of checks does matter, but
2219 // it's one of two possible causes for a change which was
2220 // reverted. Be conservative for the moment.
2221 Start =
2223 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2224 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2225 }
2226
2227 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2228 // if D + (C - D + Step * n) could be proven to not signed wrap
2229 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2230 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2231 const APInt &C = SC->getAPInt();
2232 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2233 if (D != 0) {
2234 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2235 const SCEV *SResidual =
2236 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2237 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2238 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2239 Depth + 1);
2240 }
2241 }
2242
2243 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2244 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2245 Start =
2247 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2248 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2249 }
2250 }
2251
2252 // If the input value is provably positive and we could not simplify
2253 // away the sext build a zext instead.
2255 return getZeroExtendExpr(Op, Ty, Depth + 1);
2256
2257 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2258 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2261 SmallVector<SCEVUse, 4> Operands;
2262 for (SCEVUse Operand : MinMax->operands())
2263 Operands.push_back(getSignExtendExpr(Operand, Ty));
2265 return getSMinExpr(Operands);
2266 return getSMaxExpr(Operands);
2267 }
2268
2269 // The cast wasn't folded; create an explicit cast node.
2270 // Recompute the insert position, as it may have been invalidated.
2271 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2272 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2273 Op, Ty);
2274 UniqueSCEVs.InsertNode(S, IP);
2275 S->computeAndSetCanonical(*this);
2276 registerUser(S, Op);
2277 return S;
2278}
2279
// Dispatch to the builder for the cast kind Kind: truncate, zero-extend,
// sign-extend or ptrtoint. Any other Kind is a programming error
// (llvm_unreachable).
2281 Type *Ty) {
2282 switch (Kind) {
2283 case scTruncate:
2284 return getTruncateExpr(Op, Ty);
2285 case scZeroExtend:
2286 return getZeroExtendExpr(Op, Ty);
2287 case scSignExtend:
2288 return getSignExtendExpr(Op, Ty);
2289 case scPtrToInt:
2290 return getPtrToIntExpr(Op, Ty);
2291 default:
2292 llvm_unreachable("Not a SCEV cast expression!");
2293 }
2294}
2295
2296/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2297/// unspecified bits out to the given type.
///
/// Strategy, in order:
///   - negative constants are sign-extended;
///   - a truncate is peeled off, and its operand is any-extended or truncated
///     to Ty;
///   - a zext or sext that folds away (is not a plain extend node) is used;
///   - an addrec is rebuilt from any-extended operands and flagged NW;
///   - an smax uses the sext;
///   - everything else falls back to the zext.
2299 Type *Ty) {
2300 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2301 "This is not an extending conversion!");
2302 assert(isSCEVable(Ty) &&
2303 "This is not a conversion to a SCEVable type!");
2304 Ty = getEffectiveSCEVType(Ty);
2305
2306 // Sign-extend negative constants.
2307 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2308 if (SC->getAPInt().isNegative())
2309 return getSignExtendExpr(Op, Ty);
2310
2311 // Peel off a truncate cast.
2313 const SCEV *NewOp = T->getOperand();
2314 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2315 return getAnyExtendExpr(NewOp, Ty);
2316 return getTruncateOrNoop(NewOp, Ty);
2317 }
2318
2319 // Next try a zext cast. If the cast is folded, use it.
2320 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2321 if (!isa<SCEVZeroExtendExpr>(ZExt))
2322 return ZExt;
2323
2324 // Next try a sext cast. If the cast is folded, use it.
2325 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2326 if (!isa<SCEVSignExtendExpr>(SExt))
2327 return SExt;
2328
2329 // Force the cast to be folded into the operands of an addrec.
2330 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2332 for (const SCEV *Op : AR->operands())
2333 Ops.push_back(getAnyExtendExpr(Op, Ty));
2334 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2335 }
2336
2337 // If the expression is obviously signed, use the sext cast value.
2338 if (isa<SCEVSMaxExpr>(Op))
2339 return SExt;
2340
2341 // Absent any other information, use the zext cast value.
2342 return ZExt;
2343}
2344
2345/// Process the given Ops list, which is a list of operands to be added under
2346/// the given scale, update the given map. This is a helper function for
2347/// getAddExpr. As an example of what it does, given a sequence of operands
2348/// that would form an add expression like this:
2349///
2350/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2351///
2352/// where A and B are constants, update the map with these values:
2353///
2354/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2355///
2356/// and add 13 + A*B*29 to AccumulatedConstant.
2357/// This will allow getAddExpr to produce this:
2358///
2359/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2360///
2361/// This form often exposes folding opportunities that are hidden in
2362/// the original operand list.
2363///
2364/// Return true iff it appears that any interesting folding opportunities
2365/// may be exposed. This helps getAddExpr short-circuit extra work in
2366/// the common case where no interesting opportunities are present, and
2367/// is also used as a check to avoid infinite recursion.
///
/// Outputs: each operand (or, for a constant-times-X multiply, the product X)
/// seen for the first time is inserted into M with its scale and appended to
/// NewOps; repeated ones have their scale summed in M instead. Constant
/// operands are multiplied by Scale and added into AccumulatedConstant.
/// Operands of the form (C * (add ...)) are recursed into with Scale * C.
2370 APInt &AccumulatedConstant,
2372 const APInt &Scale,
2373 ScalarEvolution &SE) {
2374 bool Interesting = false;
2375
2376 // Iterate over the add operands. They are sorted, with constants first.
 // NOTE(review): this loop reads Ops[i] without a bounds check, so it relies
 // on Ops containing at least one non-constant operand.
2377 unsigned i = 0;
2378 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2379 ++i;
2380 // Pull a buried constant out to the outside.
2381 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2382 Interesting = true;
2383 AccumulatedConstant += Scale * C->getAPInt();
2384 }
2385
2386 // Next comes everything else. We're especially interested in multiplies
2387 // here, but they're in the middle, so just visit the rest with one loop.
2388 for (; i != Ops.size(); ++i) {
2390 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2391 APInt NewScale =
2392 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2393 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2394 // A multiplication of a constant with another add; recurse.
2395 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2396 Interesting |= CollectAddOperandsWithScales(
2397 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2398 } else {
2399 // A multiplication of a constant with some other value. Update
2400 // the map.
2401 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2402 const SCEV *Key = SE.getMulExpr(MulOps);
2403 auto Pair = M.insert({Key, NewScale});
2404 if (Pair.second) {
2405 NewOps.push_back(Pair.first->first);
2406 } else {
2407 Pair.first->second += NewScale;
2408 // The map already had an entry for this value, which may indicate
2409 // a folding opportunity.
2410 Interesting = true;
2411 }
2412 }
2413 } else {
2414 // An ordinary operand. Update the map.
2415 auto Pair = M.insert({Ops[i], Scale});
2416 if (Pair.second) {
2417 NewOps.push_back(Pair.first->first);
2418 } else {
2419 Pair.first->second += Scale;
2420 // The map already had an entry for this value, which may indicate
2421 // a folding opportunity.
2422 Interesting = true;
2423 }
2424 }
2425 }
2426
2427 return Interesting;
2428}
2429
// Returns true if LHS <BinOp> RHS is known not to overflow. BinOp must be
// Add, Sub or Mul; any other value is unreachable. Signed selects whether
// signed or unsigned overflow is checked.
// First check: the narrow result, extended to an integer type twice as wide,
// must equal the same operation performed on the extended operands.
// Fallback: if that fails and CtxI is non-null, Add/Sub with a constant RHS C
// are proven safe by showing, at CtxI, that LHS stays at least |C| away from
// the type's MIN (when the result moves down) or MAX (when it moves up).
// Mul and non-constant RHS are not handled by the fallback.
2431 const SCEV *LHS, const SCEV *RHS,
2432 const Instruction *CtxI) {
2434 unsigned);
2435 switch (BinOp) {
2436 default:
2437 llvm_unreachable("Unsupported binary op");
2438 case Instruction::Add:
2440 break;
2441 case Instruction::Sub:
2443 break;
2444 case Instruction::Mul:
2446 break;
2447 }
2448
2449 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2452
2453 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2454 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2455 auto *WideTy =
2456 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2457
2458 const SCEV *A = (this->*Extension)(
2459 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2460 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2461 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2462 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2463 if (A == B)
2464 return true;
2465 // Can we use context to prove the fact we need?
2466 if (!CtxI)
2467 return false;
2468 // TODO: Support mul.
2469 if (BinOp == Instruction::Mul)
2470 return false;
2471 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2472 // TODO: Lift this limitation.
2473 if (!RHSC)
2474 return false;
2475 APInt C = RHSC->getAPInt();
2476 unsigned NumBits = C.getBitWidth();
2477 bool IsSub = (BinOp == Instruction::Sub);
2478 bool IsNegativeConst = (Signed && C.isNegative());
2479 // Compute the direction and magnitude by which we need to check overflow.
 // Subtracting a positive constant, or adding a negative one (signed case
 // only), moves the result downward.
2480 bool OverflowDown = IsSub ^ IsNegativeConst;
2481 APInt Magnitude = C;
2482 if (IsNegativeConst) {
2483 if (C == APInt::getSignedMinValue(NumBits))
2484 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2485 // want to deal with that.
2486 return false;
2487 Magnitude = -C;
2488 }
2489
2491 if (OverflowDown) {
2492 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2493 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2494 : APInt::getMinValue(NumBits);
2495 APInt Limit = Min + Magnitude;
2496 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2497 } else {
2498 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2499 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2500 : APInt::getMaxValue(NumBits);
2501 APInt Limit = Max - Magnitude;
2502 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2503 }
2504}
2505
// Try to deduce nuw and/or nsw for the Add/Sub/Mul operator OBO beyond the
// flags it already carries, by analyzing the SCEVs of its two operands.
// Returns std::nullopt in three cases:
//   - OBO already has both flags;
//   - OBO's opcode is not Add, Sub or Mul;
//   - no new flag could be deduced.
// Otherwise it returns the accumulated Flags.
2506std::optional<SCEV::NoWrapFlags>
2508 const OverflowingBinaryOperator *OBO) {
2509 // It cannot be done any better.
2510 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2511 return std::nullopt;
2512
2513 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2514
2515 if (OBO->hasNoUnsignedWrap())
2517 if (OBO->hasNoSignedWrap())
2519
 // Set once at least one flag missing from OBO has been proven.
2520 bool Deduced = false;
2521
2522 if (OBO->getOpcode() != Instruction::Add &&
2523 OBO->getOpcode() != Instruction::Sub &&
2524 OBO->getOpcode() != Instruction::Mul)
2525 return std::nullopt;
2526
2527 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2528 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2529
2530 const Instruction *CtxI =
2532 if (!OBO->hasNoUnsignedWrap() &&
2534 /* Signed */ false, LHS, RHS, CtxI)) {
2536 Deduced = true;
2537 }
2538
2539 if (!OBO->hasNoSignedWrap() &&
2541 /* Signed */ true, LHS, RHS, CtxI)) {
2543 Deduced = true;
2544 }
2545
2546 if (Deduced)
2547 return Flags;
2548 return std::nullopt;
2549}
2550
2551// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2552// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2553// can't-overflow flags for the operation if possible.
//
// Inference rules applied:
//   - nsw with all operands known non-negative implies nuw as well.
//   - For a two-operand add or mul whose first operand is a constant, nsw
//     and nuw are derived from the no-wrap region of the constant and the
//     range of the other operand.
//   - A two-operand <0,+,non-negative> that is nw is also nuw.
//   - (udiv X, Y) * Y is always nuw.
2557 SCEV::NoWrapFlags Flags) {
2558 using namespace std::placeholders;
2559
2560 using OBO = OverflowingBinaryOperator;
2561
2562 bool CanAnalyze =
2564 (void)CanAnalyze;
2565 assert(CanAnalyze && "don't call from other places!");
2566
2567 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2568 SCEV::NoWrapFlags SignOrUnsignWrap =
2569 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2570
2571 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2572 auto IsKnownNonNegative = [&](SCEVUse U) {
2573 return SE->isKnownNonNegative(U);
2574 };
2575
2576 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2577 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2578
 // Recompute: the rule above may have added FlagNUW.
2579 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2580
2581 if (SignOrUnsignWrap != SignOrUnsignMask &&
2582 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2583 isa<SCEVConstant>(Ops[0])) {
2584
2585 auto Opcode = [&] {
2586 switch (Type) {
2587 case scAddExpr:
2588 return Instruction::Add;
2589 case scMulExpr:
2590 return Instruction::Mul;
2591 default:
2592 llvm_unreachable("Unexpected SCEV op.");
2593 }
2594 }();
2595
2596 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2597
2598 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op can't overflow as signed.
2599 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2601 Opcode, C, OBO::NoSignedWrap);
2602 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2604 }
2605
2606 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op can't overflow as unsigned.
2607 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2609 Opcode, C, OBO::NoUnsignedWrap);
2610 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2612 }
2613 }
2614
2615 // <0,+,nonnegative><nw> is also nuw
2616 // TODO: Add corresponding nsw case
2618 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2619 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2621
2622 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2624 Ops.size() == 2) {
2625 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2626 if (UDiv->getOperand(1) == Ops[1])
2628 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2629 if (UDiv->getOperand(1) == Ops[0])
2631 }
2632
2633 return Flags;
2634}
2635
// S is usable at the entry of loop L when it is invariant in L and properly
// dominates L's header block.
2637 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2638}
2639
2640/// Get a canonical add expression, or something simpler if possible.
2642 SCEV::NoWrapFlags OrigFlags,
2643 unsigned Depth) {
2644 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2645 "only nuw or nsw allowed");
2646 assert(!Ops.empty() && "Cannot get empty add!");
2647 if (Ops.size() == 1) return Ops[0];
2648#ifndef NDEBUG
2649 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2650 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2651 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2652 "SCEVAddExpr operand types don't match!");
2653 unsigned NumPtrs = count_if(
2654 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2655 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2656#endif
2657
2658 const SCEV *Folded = constantFoldAndGroupOps(
2659 *this, LI, DT, Ops,
2660 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2661 [](const APInt &C) { return C.isZero(); }, // identity
2662 [](const APInt &C) { return false; }); // absorber
2663 if (Folded)
2664 return Folded;
2665
2666 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2667
2668 // Delay expensive flag strengthening until necessary.
2669 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2670 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2671 };
2672
2673 // Limit recursion calls depth.
2675 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2676
2677 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2678 // Don't strengthen flags if we have no new information.
2679 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2680 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2681 Add->setNoWrapFlags(ComputeFlags(Ops));
2682 return S;
2683 }
2684
2685 // Okay, check to see if the same value occurs in the operand list more than
2686 // once. If so, merge them together into an multiply expression. Since we
2687 // sorted the list, these values are required to be adjacent.
2688 Type *Ty = Ops[0]->getType();
2689 bool FoundMatch = false;
2690 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2691 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2692 // Scan ahead to count how many equal operands there are.
2693 unsigned Count = 2;
2694 while (i+Count != e && Ops[i+Count] == Ops[i])
2695 ++Count;
2696 // Merge the values into a multiply.
2697 SCEVUse Scale = getConstant(Ty, Count);
2698 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2699 if (Ops.size() == Count)
2700 return Mul;
2701 Ops[i] = Mul;
2702 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2703 --i; e -= Count - 1;
2704 FoundMatch = true;
2705 }
2706 if (FoundMatch)
2707 return getAddExpr(Ops, OrigFlags, Depth + 1);
2708
2709 // Check for truncates. If all the operands are truncated from the same
2710 // type, see if factoring out the truncate would permit the result to be
2711 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2712 // if the contents of the resulting outer trunc fold to something simple.
2713 auto FindTruncSrcType = [&]() -> Type * {
2714 // We're ultimately looking to fold an addrec of truncs and muls of only
2715 // constants and truncs, so if we find any other types of SCEV
2716 // as operands of the addrec then we bail and return nullptr here.
2717 // Otherwise, we return the type of the operand of a trunc that we find.
2718 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2719 return T->getOperand()->getType();
2720 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2721 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2722 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2723 return T->getOperand()->getType();
2724 }
2725 return nullptr;
2726 };
2727 if (auto *SrcType = FindTruncSrcType()) {
2728 SmallVector<SCEVUse, 8> LargeOps;
2729 bool Ok = true;
2730 // Check all the operands to see if they can be represented in the
2731 // source type of the truncate.
2732 for (const SCEV *Op : Ops) {
2734 if (T->getOperand()->getType() != SrcType) {
2735 Ok = false;
2736 break;
2737 }
2738 LargeOps.push_back(T->getOperand());
2739 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2740 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2741 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2742 SmallVector<SCEVUse, 8> LargeMulOps;
2743 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2744 if (const SCEVTruncateExpr *T =
2745 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2746 if (T->getOperand()->getType() != SrcType) {
2747 Ok = false;
2748 break;
2749 }
2750 LargeMulOps.push_back(T->getOperand());
2751 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2752 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2753 } else {
2754 Ok = false;
2755 break;
2756 }
2757 }
2758 if (Ok)
2759 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2760 } else {
2761 Ok = false;
2762 break;
2763 }
2764 }
2765 if (Ok) {
2766 // Evaluate the expression in the larger type.
2767 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2768 // If it folds to something simple, use it. Otherwise, don't.
2769 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2770 return getTruncateExpr(Fold, Ty);
2771 }
2772 }
2773
2774 if (Ops.size() == 2) {
2775 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2776 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2777 // C1).
2778 const SCEV *A = Ops[0];
2779 const SCEV *B = Ops[1];
2780 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2781 auto *C = dyn_cast<SCEVConstant>(A);
2782 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2783 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2784 auto C2 = C->getAPInt();
2785 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2786
2787 APInt ConstAdd = C1 + C2;
2788 auto AddFlags = AddExpr->getNoWrapFlags();
2789 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2791 ConstAdd.ule(C1)) {
2792 PreservedFlags =
2794 }
2795
2796 // Adding a constant with the same sign and small magnitude is NSW, if the
2797 // original AddExpr was NSW.
2799 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2800 ConstAdd.abs().ule(C1.abs())) {
2801 PreservedFlags =
2803 }
2804
2805 if (PreservedFlags != SCEV::FlagAnyWrap) {
2806 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2807 NewOps[0] = getConstant(ConstAdd);
2808 return getAddExpr(NewOps, PreservedFlags);
2809 }
2810 }
2811
2812 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2813 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2814 const SCEVAddExpr *InnerAdd;
2815 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2816 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2817 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2818 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2819 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2821 SCEV::FlagNUW)) {
2822 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2823 }
2824 }
2825 }
2826
2827 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2828 const SCEV *Y;
2829 if (Ops.size() == 2 &&
2830 match(Ops[0],
2832 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2833 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2834
2835 // Skip past any other cast SCEVs.
2836 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2837 ++Idx;
2838
2839 // If there are add operands they would be next.
2840 if (Idx < Ops.size()) {
2841 bool DeletedAdd = false;
2842 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2843 // common NUW flag for expression after inlining. Other flags cannot be
2844 // preserved, because they may depend on the original order of operations.
2845 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2846 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2847 if (Ops.size() > AddOpsInlineThreshold ||
2848 Add->getNumOperands() > AddOpsInlineThreshold)
2849 break;
2850 // If we have an add, expand the add operands onto the end of the operands
2851 // list.
2852 Ops.erase(Ops.begin()+Idx);
2853 append_range(Ops, Add->operands());
2854 DeletedAdd = true;
2855 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2856 }
2857
2858 // If we deleted at least one add, we added operands to the end of the list,
2859 // and they are not necessarily sorted. Recurse to resort and resimplify
2860 // any operands we just acquired.
2861 if (DeletedAdd)
2862 return getAddExpr(Ops, CommonFlags, Depth + 1);
2863 }
2864
2865 // Skip over the add expression until we get to a multiply.
2866 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2867 ++Idx;
2868
2869 // Check to see if there are any folding opportunities present with
2870 // operands multiplied by constant values.
2871 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2875 APInt AccumulatedConstant(BitWidth, 0);
2876 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2877 Ops, APInt(BitWidth, 1), *this)) {
2878 struct APIntCompare {
2879 bool operator()(const APInt &LHS, const APInt &RHS) const {
2880 return LHS.ult(RHS);
2881 }
2882 };
2883
2884 // Some interesting folding opportunity is present, so its worthwhile to
2885 // re-generate the operands list. Group the operands by constant scale,
2886 // to avoid multiplying by the same constant scale multiple times.
2887 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2888 for (const SCEV *NewOp : NewOps)
2889 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2890 // Re-generate the operands list.
2891 Ops.clear();
2892 if (AccumulatedConstant != 0)
2893 Ops.push_back(getConstant(AccumulatedConstant));
2894 for (auto &MulOp : MulOpLists) {
2895 if (MulOp.first == 1) {
2896 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2897 } else if (MulOp.first != 0) {
2898 Ops.push_back(getMulExpr(
2899 getConstant(MulOp.first),
2900 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2901 SCEV::FlagAnyWrap, Depth + 1));
2902 }
2903 }
2904 if (Ops.empty())
2905 return getZero(Ty);
2906 if (Ops.size() == 1)
2907 return Ops[0];
2908 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2909 }
2910 }
2911
2912 // Given a SCEVMulExpr and an operand index, return the product of all
2913 // operands except the one at OpIdx.
2914 auto StripFactor = [&](const SCEVMulExpr *M, unsigned OpIdx) -> SCEVUse {
2915 if (M->getNumOperands() == 2)
2916 return M->getOperand(OpIdx == 0);
2917 SmallVector<SCEVUse, 4> Remaining(M->operands().take_front(OpIdx));
2918 append_range(Remaining, M->operands().drop_front(OpIdx + 1));
2919 return getMulExpr(Remaining, SCEV::FlagAnyWrap, Depth + 1);
2920 };
2921
2922 // If we are adding something to a multiply expression, make sure the
2923 // something is not already an operand of the multiply. If so, merge it into
2924 // the multiply.
2925 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2926 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2927 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2928 // Scan all terms to find every occurrence of common factor MulOpSCEV
2929 // and fold them in one shot:
2930 // A1*X + A2*X + ... + An*X --> X * (A1 + A2 + ... + An)
2931 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2932 if (isa<SCEVConstant>(MulOpSCEV))
2933 continue;
2934
2935 // Cofactors: 1 for bare addends matching MulOpSCEV, or the
2936 // remaining product for multiply terms containing MulOpSCEV.
2937 SmallVector<SCEVUse, 4> Cofactors;
2938 SmallVector<unsigned, 4> DeadIndices;
2939 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) {
2940 if (MulOpSCEV == Ops[AddOp]) {
2941 // W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2942 Cofactors.push_back(getOne(Ty));
2943 DeadIndices.push_back(AddOp);
2944 continue;
2945 }
2946
2947 if (AddOp <= Idx || !isa<SCEVMulExpr>(Ops[AddOp]))
2948 continue;
2949
2950 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[AddOp]);
2951 for (unsigned OMulOp = 0, OE = OtherMul->getNumOperands(); OMulOp != OE;
2952 ++OMulOp) {
2953 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2954 // (A*B*C) + (A*D*E) --> A * (B*C + D*E)
2955 Cofactors.push_back(StripFactor(OtherMul, OMulOp));
2956 DeadIndices.push_back(AddOp);
2957 break;
2958 }
2959 }
2960 }
2961
2962 // Fold all collected cofactors with the anchor multiply's cofactor:
2963 // MulOpSCEV * (Cofactor_1 + ... + Cofactor_n + AnchorCofactor)
2964 if (!Cofactors.empty()) {
2965 Cofactors.push_back(StripFactor(Mul, MulOp));
2966
2967 SCEVUse InnerSum = getAddExpr(Cofactors, SCEV::FlagAnyWrap, Depth + 1);
2968 SCEVUse OuterMul =
2969 getMulExpr(MulOpSCEV, InnerSum, SCEV::FlagAnyWrap, Depth + 1);
2970
2971 // DeadIndices does not include Idx (the anchor), hence +1.
2972 if (Ops.size() == DeadIndices.size() + 1)
2973 return OuterMul;
2974
2975 // Erase Ops[Idx] first, then erase DeadIndices in reverse order.
2976 // The -1 adjustment accounts for the shift from removing Idx;
2977 // reverse order means each erasure only shifts later positions,
2978 // which have already been processed.
2979 Ops.erase(Ops.begin() + Idx);
2980 for (unsigned Dead : reverse(DeadIndices))
2981 Ops.erase(Ops.begin() + (Dead > Idx ? Dead - 1 : Dead));
2982
2983 Ops.push_back(OuterMul);
2984 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2985 }
2986 }
2987 }
2988
2989 // If there are any add recurrences in the operands list, see if any other
2990 // added values are loop invariant. If so, we can fold them into the
2991 // recurrence.
2992 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2993 ++Idx;
2994
2995 // Scan over all recurrences, trying to fold loop invariants into them.
2996 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2997 // Scan all of the other operands to this add and add them to the vector if
2998 // they are loop invariant w.r.t. the recurrence.
3000 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3001 const Loop *AddRecLoop = AddRec->getLoop();
3002 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3003 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3004 LIOps.push_back(Ops[i]);
3005 Ops.erase(Ops.begin()+i);
3006 --i; --e;
3007 }
3008
3009 // If we found some loop invariants, fold them into the recurrence.
3010 if (!LIOps.empty()) {
3011 // Compute nowrap flags for the addition of the loop-invariant ops and
3012 // the addrec. Temporarily push it as an operand for that purpose. These
3013 // flags are valid in the scope of the addrec only.
3014 LIOps.push_back(AddRec);
3015 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
3016 LIOps.pop_back();
3017
3018 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3019 LIOps.push_back(AddRec->getStart());
3020
3021 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3022
3023 // It is not in general safe to propagate flags valid on an add within
3024 // the addrec scope to one outside it. We must prove that the inner
3025 // scope is guaranteed to execute if the outer one does to be able to
3026 // safely propagate. We know the program is undefined if poison is
3027 // produced on the inner scoped addrec. We also know that *for this use*
3028 // the outer scoped add can't overflow (because of the flags we just
3029 // computed for the inner scoped add) without the program being undefined.
3030 // Proving that entry to the outer scope neccesitates entry to the inner
3031 // scope, thus proves the program undefined if the flags would be violated
3032 // in the outer scope.
3033 SCEV::NoWrapFlags AddFlags = Flags;
3034 if (AddFlags != SCEV::FlagAnyWrap) {
3035 auto *DefI = getDefiningScopeBound(LIOps);
3036 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3037 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3038 AddFlags = SCEV::FlagAnyWrap;
3039 }
3040 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3041
3042 // Build the new addrec. Propagate the NUW and NSW flags if both the
3043 // outer add and the inner addrec are guaranteed to have no overflow.
3044 // Always propagate NW.
3045 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3046 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3047
3048 // If all of the other operands were loop invariant, we are done.
3049 if (Ops.size() == 1) return NewRec;
3050
3051 // Otherwise, add the folded AddRec by the non-invariant parts.
3052 for (unsigned i = 0;; ++i)
3053 if (Ops[i] == AddRec) {
3054 Ops[i] = NewRec;
3055 break;
3056 }
3057 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3058 }
3059
3060 // Okay, if there weren't any loop invariants to be folded, check to see if
3061 // there are multiple AddRec's with the same loop induction variable being
3062 // added together. If so, we can fold them.
3063 for (unsigned OtherIdx = Idx+1;
3064 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3065 ++OtherIdx) {
3066 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3067 // so that the 1st found AddRecExpr is dominated by all others.
3068 assert(DT.dominates(
3069 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3070 AddRec->getLoop()->getHeader()) &&
3071 "AddRecExprs are not sorted in reverse dominance order?");
3072 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3073 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3074 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3075 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3076 ++OtherIdx) {
3077 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3078 if (OtherAddRec->getLoop() == AddRecLoop) {
3079 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3080 i != e; ++i) {
3081 if (i >= AddRecOps.size()) {
3082 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3083 break;
3084 }
3085 AddRecOps[i] =
3086 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3088 }
3089 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3090 }
3091 }
3092 // Step size has changed, so we cannot guarantee no self-wraparound.
3093 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3094 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3095 }
3096 }
3097
3098 // Otherwise couldn't fold anything into this recurrence. Move onto the
3099 // next one.
3100 }
3101
3102 // Okay, it looks like we really DO need an add expr. Check to see if we
3103 // already have one, otherwise create a new one.
3104 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3105}
3106
3107const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3108 SCEV::NoWrapFlags Flags) {
3110 ID.AddInteger(scAddExpr);
3111 for (const SCEV *Op : Ops)
3112 ID.AddPointer(Op);
3113 void *IP = nullptr;
3114 SCEVAddExpr *S =
3115 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3116 if (!S) {
3117 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3119 S = new (SCEVAllocator)
3120 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3121 UniqueSCEVs.InsertNode(S, IP);
3122 S->computeAndSetCanonical(*this);
3123 registerUser(S, Ops);
3124 }
3125 S->setNoWrapFlags(Flags);
3126 return S;
3127}
3128
3129const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3130 const Loop *L,
3131 SCEV::NoWrapFlags Flags) {
3132 FoldingSetNodeID ID;
3133 ID.AddInteger(scAddRecExpr);
3134 for (const SCEV *Op : Ops)
3135 ID.AddPointer(Op);
3136 ID.AddPointer(L);
3137 void *IP = nullptr;
3138 SCEVAddRecExpr *S =
3139 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3140 if (!S) {
3141 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3143 S = new (SCEVAllocator)
3144 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3145 UniqueSCEVs.InsertNode(S, IP);
3146 S->computeAndSetCanonical(*this);
3147 LoopUsers[L].push_back(S);
3148 registerUser(S, Ops);
3149 }
3150 setNoWrapFlags(S, Flags);
3151 return S;
3152}
3153
3154const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3155 SCEV::NoWrapFlags Flags) {
3156 FoldingSetNodeID ID;
3157 ID.AddInteger(scMulExpr);
3158 for (const SCEV *Op : Ops)
3159 ID.AddPointer(Op);
3160 void *IP = nullptr;
3161 SCEVMulExpr *S =
3162 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3163 if (!S) {
3164 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3166 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3167 O, Ops.size());
3168 UniqueSCEVs.InsertNode(S, IP);
3169 S->computeAndSetCanonical(*this);
3170 registerUser(S, Ops);
3171 }
3172 S->setNoWrapFlags(Flags);
3173 return S;
3174}
3175
/// Multiply \p i by \p j modulo 2^64. If the mathematical product does not fit
/// in 64 bits, set \p Overflow. Overflow is never cleared here.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  // Multiplying by 0 or 1 cannot overflow, and skipping j == 0 also avoids
  // dividing by zero in the check below.
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. Returns 0 when
/// k > n (including n == 0 with k > 0). If an intermediate computation
/// overflows, Overflow will be set and the return will be garbage. Overflow is
/// not cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At iteration i we multiply by the i-th factor of the numerator,
  // (n-(i-1)), and then divide by i. Absent overflow, the running value after
  // iteration i is exactly C(n, i), so every division is exact. Dividing at
  // each step also reduces the chance of overflow in the intermediate
  // computations. However, we can still overflow even when the final result
  // would fit.

  // Test k > n before the trivial cases: C(0, k) is 0 for k > 0, not 1.
  if (k > n)
    return 0;
  if (k == 0 || k == n)
    return 1;

  // C(n, k) == C(n, n-k); iterate over the smaller of the two.
  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}
3207
3208/// Determine if any of the operands in this SCEV are a constant or if
3209/// any of the add or multiply expressions in this SCEV contain a constant.
3210static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3211 struct FindConstantInAddMulChain {
3212 bool FoundConstant = false;
3213
3214 bool follow(const SCEV *S) {
3215 FoundConstant |= isa<SCEVConstant>(S);
3216 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3217 }
3218
3219 bool isDone() const {
3220 return FoundConstant;
3221 }
3222 };
3223
3224 FindConstantInAddMulChain F;
3226 ST.visitAll(StartExpr);
3227 return F.FoundConstant;
3228}
3229
3230/// Get a canonical multiply expression, or something simpler if possible.
3232 SCEV::NoWrapFlags OrigFlags,
3233 unsigned Depth) {
3234 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3235 "only nuw or nsw allowed");
3236 assert(!Ops.empty() && "Cannot get empty mul!");
3237 if (Ops.size() == 1) return Ops[0];
3238#ifndef NDEBUG
3239 Type *ETy = Ops[0]->getType();
3240 assert(!ETy->isPointerTy());
3241 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3242 assert(Ops[i]->getType() == ETy &&
3243 "SCEVMulExpr operand types don't match!");
3244#endif
3245
3246 const SCEV *Folded = constantFoldAndGroupOps(
3247 *this, LI, DT, Ops,
3248 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3249 [](const APInt &C) { return C.isOne(); }, // identity
3250 [](const APInt &C) { return C.isZero(); }); // absorber
3251 if (Folded)
3252 return Folded;
3253
3254 // Delay expensive flag strengthening until necessary.
3255 auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
3256 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3257 };
3258
3259 // Limit recursion calls depth.
3261 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3262
3263 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3264 // Don't strengthen flags if we have no new information.
3265 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3266 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3267 Mul->setNoWrapFlags(ComputeFlags(Ops));
3268 return S;
3269 }
3270
3271 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3272 if (Ops.size() == 2) {
3273 // C1*(C2+V) -> C1*C2 + C1*V
3274 // If any of Add's ops are Adds or Muls with a constant, apply this
3275 // transformation as well.
3276 //
3277 // TODO: There are some cases where this transformation is not
3278 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3279 // this transformation should be narrowed down.
3280 const SCEV *Op0, *Op1;
3281 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3283 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3284 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3285 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3286 }
3287
3288 if (Ops[0]->isAllOnesValue()) {
3289 // If we have a mul by -1 of an add, try distributing the -1 among the
3290 // add operands.
3291 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3293 bool AnyFolded = false;
3294 for (const SCEV *AddOp : Add->operands()) {
3295 const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
3297 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3298 NewOps.push_back(Mul);
3299 }
3300 if (AnyFolded)
3301 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3302 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3303 // Negation preserves a recurrence's no self-wrap property.
3304 SmallVector<SCEVUse, 4> Operands;
3305 for (const SCEV *AddRecOp : AddRec->operands())
3306 Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
3307 SCEV::FlagAnyWrap, Depth + 1));
3308 // Let M be the minimum representable signed value. AddRec with nsw
3309 // multiplied by -1 can have signed overflow if and only if it takes a
3310 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3311 // maximum signed value. In all other cases signed overflow is
3312 // impossible.
3313 auto FlagsMask = SCEV::FlagNW;
3314 if (AddRec->hasNoSignedWrap()) {
3315 auto MinInt =
3316 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3317 if (getSignedRangeMin(AddRec) != MinInt)
3318 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3319 }
3320 return getAddRecExpr(Operands, AddRec->getLoop(),
3321 AddRec->getNoWrapFlags(FlagsMask));
3322 }
3323 }
3324
3325 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3326 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3327 const SCEVAddExpr *InnerAdd;
3328 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3329 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3330 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3331 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3332 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3334 SCEV::FlagNUW)) {
3335 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3336 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3337 };
3338 }
3339
3340 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3341 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3342 // of C1, fold to (D /u (C2 /u C1)).
3343 const SCEV *D;
3344 APInt C1V = LHSC->getAPInt();
3345 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3346 // as -1 * 1, as it won't enable additional folds.
3347 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3348 C1V = C1V.abs();
3349 const SCEVConstant *C2;
3350 if (C1V.isPowerOf2() &&
3352 C2->getAPInt().isPowerOf2() &&
3353 C1V.logBase2() <= getMinTrailingZeros(D)) {
3354 const SCEV *NewMul = nullptr;
3355 if (C1V.uge(C2->getAPInt())) {
3356 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3357 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3358 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3359 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3360 }
3361 if (NewMul)
3362 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3363 }
3364 }
3365 }
3366
3367 // Skip over the add expression until we get to a multiply.
3368 unsigned Idx = 0;
3369 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3370 ++Idx;
3371
3372 // If there are mul operands inline them all into this expression.
3373 if (Idx < Ops.size()) {
3374 bool DeletedMul = false;
3375 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3376 if (Ops.size() > MulOpsInlineThreshold)
3377 break;
3378 // If we have an mul, expand the mul operands onto the end of the
3379 // operands list.
3380 Ops.erase(Ops.begin()+Idx);
3381 append_range(Ops, Mul->operands());
3382 DeletedMul = true;
3383 }
3384
3385 // If we deleted at least one mul, we added operands to the end of the
3386 // list, and they are not necessarily sorted. Recurse to resort and
3387 // resimplify any operands we just acquired.
3388 if (DeletedMul)
3389 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3390 }
3391
3392 // If there are any add recurrences in the operands list, see if any other
3393 // added values are loop invariant. If so, we can fold them into the
3394 // recurrence.
3395 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3396 ++Idx;
3397
3398 // Scan over all recurrences, trying to fold loop invariants into them.
3399 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3400 // Scan all of the other operands to this mul and add them to the vector
3401 // if they are loop invariant w.r.t. the recurrence.
3403 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3404 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3405 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3406 LIOps.push_back(Ops[i]);
3407 Ops.erase(Ops.begin()+i);
3408 --i; --e;
3409 }
3410
3411 // If we found some loop invariants, fold them into the recurrence.
3412 if (!LIOps.empty()) {
3413 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3415 NewOps.reserve(AddRec->getNumOperands());
3416 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3417
3418 // If both the mul and addrec are nuw, we can preserve nuw.
3419 // If both the mul and addrec are nsw, we can only preserve nsw if either
3420 // a) they are also nuw, or
3421 // b) all multiplications of addrec operands with scale are nsw.
3422 SCEV::NoWrapFlags Flags =
3423 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3424
3425 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3426 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3427 SCEV::FlagAnyWrap, Depth + 1));
3428
3429 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3431 Instruction::Mul, getSignedRange(Scale),
3433 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3434 Flags = clearFlags(Flags, SCEV::FlagNSW);
3435 }
3436 }
3437
3438 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3439
3440 // If all of the other operands were loop invariant, we are done.
3441 if (Ops.size() == 1) return NewRec;
3442
3443 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3444 for (unsigned i = 0;; ++i)
3445 if (Ops[i] == AddRec) {
3446 Ops[i] = NewRec;
3447 break;
3448 }
3449 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3450 }
3451
3452 // Okay, if there weren't any loop invariants to be folded, check to see
3453 // if there are multiple AddRec's with the same loop induction variable
3454 // being multiplied together. If so, we can fold them.
3455
3456 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3457 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3458 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3459 // ]]],+,...up to x=2n}.
3460 // Note that the arguments to choose() are always integers with values
3461 // known at compile time, never SCEV objects.
3462 //
3463 // The implementation avoids pointless extra computations when the two
3464 // addrec's are of different length (mathematically, it's equivalent to
3465 // an infinite stream of zeros on the right).
3466 bool OpsModified = false;
3467 for (unsigned OtherIdx = Idx+1;
3468 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3469 ++OtherIdx) {
3470 const SCEVAddRecExpr *OtherAddRec =
3471 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3472 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3473 continue;
3474
3475 // Limit max number of arguments to avoid creation of unreasonably big
3476 // SCEVAddRecs with very complex operands.
3477 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3478 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3479 continue;
3480
3481 bool Overflow = false;
3482 Type *Ty = AddRec->getType();
3483 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3484 SmallVector<SCEVUse, 7> AddRecOps;
3485 for (int x = 0, xe = AddRec->getNumOperands() +
3486 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3488 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3489 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3490 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3491 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3492 z < ze && !Overflow; ++z) {
3493 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3494 uint64_t Coeff;
3495 if (LargerThan64Bits)
3496 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3497 else
3498 Coeff = Coeff1*Coeff2;
3499 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3500 const SCEV *Term1 = AddRec->getOperand(y-z);
3501 const SCEV *Term2 = OtherAddRec->getOperand(z);
3502 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3503 SCEV::FlagAnyWrap, Depth + 1));
3504 }
3505 }
3506 if (SumOps.empty())
3507 SumOps.push_back(getZero(Ty));
3508 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3509 }
3510 if (!Overflow) {
3511 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3513 if (Ops.size() == 2) return NewAddRec;
3514 Ops[Idx] = NewAddRec;
3515 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3516 OpsModified = true;
3517 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3518 if (!AddRec)
3519 break;
3520 }
3521 }
3522 if (OpsModified)
3523 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3524
3525 // Otherwise couldn't fold anything into this recurrence. Move onto the
3526 // next one.
3527 }
3528
3529 // Okay, it looks like we really DO need an mul expr. Check to see if we
3530 // already have one, otherwise create a new one.
3531 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3532}
3533
3534/// Represents an unsigned remainder expression based on unsigned division.
3536 assert(getEffectiveSCEVType(LHS->getType()) ==
3537 getEffectiveSCEVType(RHS->getType()) &&
3538 "SCEVURemExpr operand types don't match!");
3539
3540 // Short-circuit easy cases
3541 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3542 // If constant is one, the result is trivial
3543 if (RHSC->getValue()->isOne())
3544 return getZero(LHS->getType()); // X urem 1 --> 0
3545
3546 // If constant is a power of two, fold into a zext(trunc(LHS)).
3547 if (RHSC->getAPInt().isPowerOf2()) {
3548 Type *FullTy = LHS->getType();
3549 Type *TruncTy =
3550 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3551 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3552 }
3553 }
3554
3555 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3556 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3557 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3558 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3559}
3560
3561/// Get a canonical unsigned division expression, or something simpler if
3562/// possible.
/// Before creating a new node, this returns an already-uniqued node if one
/// exists and tries several folds:
///  - 0 /u Y and X /u 1.
///  - For a non-zero constant divisor C: distribution over add recurrences,
///    multiplies and adds when zero-extension shows the operand does not
///    wrap; merging (A /u B) /u C; a round-up-to-multiple idiom; and constant
///    folding.
///  - ((-C + (C smax X)) /u X) --> 0.
///  - (A * B)<nuw> /u B --> A.
///  - Dividing two nuw multiples of vscale.
3564 assert(!LHS->getType()->isPointerTy() &&
3565 "SCEVUDivExpr operand can't be pointer!");
3566 assert(LHS->getType() == RHS->getType() &&
3567 "SCEVUDivExpr operand types don't match!");
3568
3570 ID.AddInteger(scUDivExpr);
3571 ID.AddPointer(LHS);
3572 ID.AddPointer(RHS);
3573 void *IP = nullptr;
3574 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3575 return S;
3576
3577 // 0 udiv Y == 0
3578 if (match(LHS, m_scev_Zero()))
3579 return LHS;
3580
3581 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3582 if (RHSC->getValue()->isOne())
3583 return LHS; // X udiv 1 --> x
3584 // If the denominator is zero, the result of the udiv is undefined. Don't
3585 // try to analyze it, because the resolution chosen here may differ from
3586 // the resolution chosen in other parts of the compiler.
3587 if (!RHSC->getValue()->isZero()) {
3588 // Determine if the division can be distributed over the operands
3589 // of LHS.
3590 // TODO: Generalize this to non-constants by using known-bits information.
3591 Type *Ty = LHS->getType();
3592 unsigned LZ = RHSC->getAPInt().countl_zero();
3593 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3594 // For non-power-of-two values, effectively round the value up to the
3595 // nearest power of two.
3596 if (!RHSC->getAPInt().isPowerOf2())
3597 ++MaxShiftAmt;
// ExtTy is Ty widened by MaxShiftAmt (= ceil(log2(C))) bits. The folds
// below compare zext-to-ExtTy of an expression against the same expression
// rebuilt from zero-extended operands. Equality is the "if safe" condition:
// the expression does not wrap in Ty.
3598 IntegerType *ExtTy =
3599 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3600 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3601 if (const SCEVConstant *Step =
3602 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3603 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3604 const APInt &StepInt = Step->getAPInt();
3605 const APInt &DivInt = RHSC->getAPInt();
3606 if (!StepInt.urem(DivInt) &&
3607 getZeroExtendExpr(AR, ExtTy) ==
3608 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3609 getZeroExtendExpr(Step, ExtTy),
3610 AR->getLoop(), SCEV::FlagAnyWrap)) {
3611 SmallVector<SCEVUse, 4> Operands;
3612 for (const SCEV *Op : AR->operands())
3613 Operands.push_back(getUDivExpr(Op, RHS));
3614 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3615 }
3616 /// Get a canonical UDivExpr for a recurrence.
3617 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3618 const APInt *StartRem;
3619 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3620 m_scev_APInt(StartRem))) {
3621 bool NoWrap =
3622 getZeroExtendExpr(AR, ExtTy) ==
3623 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3624 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3626
3627 // With N <= C and both N, C as powers-of-2, the transformation
3628 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3629 // if wrapping occurs, as the division results remain equivalent for
3630 // all offsets in [(X - X%N), X).
3631 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3632 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3633 // Only fold if the subtraction can be folded in the start
3634 // expression.
3635 const SCEV *NewStart =
3636 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3637 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3638 !isa<SCEVAddExpr>(NewStart)) {
3639 const SCEV *NewLHS =
3640 getAddRecExpr(NewStart, Step, AR->getLoop(),
3641 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3642 if (LHS != NewLHS) {
// LHS is replaced in place. All remaining folds and the node
// created at the end use the canonicalized recurrence.
3643 LHS = NewLHS;
3644
3645 // Reset the ID to include the new LHS, and check if it is
3646 // already cached.
3647 ID.clear();
3648 ID.AddInteger(scUDivExpr);
3649 ID.AddPointer(LHS);
3650 ID.AddPointer(RHS);
3651 IP = nullptr;
3652 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3653 return S;
3654 }
3655 }
3656 }
3657 }
3658 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3659 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3660 SmallVector<SCEVUse, 4> Operands;
3661 for (const SCEV *Op : M->operands())
3662 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3663 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) {
3664 // Find an operand that's safely divisible.
3665 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3666 const SCEV *Op = M->getOperand(i);
3667 const SCEV *Div = getUDivExpr(Op, RHSC);
3668 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3669 Operands = SmallVector<SCEVUse, 4>(M->operands());
3670 Operands[i] = Div;
3671 return getMulExpr(Operands);
3672 }
3673 }
3674
3675 // Even if it's not divisible, try to remove a common factor.
3676 if (const auto *LHSC = dyn_cast<SCEVConstant>(M->getOperand(0))) {
3677 APInt Factor = APIntOps::GreatestCommonDivisor(LHSC->getAPInt(),
3678 RHSC->getAPInt());
// isIntN(1) holds for 0 and 1, so this only proceeds when Factor > 1.
3679 if (!Factor.isIntN(1)) {
3680 SmallVector<SCEVUse, 2> NewOperands;
3681 NewOperands.push_back(getConstant(LHSC->getAPInt().udiv(Factor)));
3682 append_range(NewOperands, M->operands().drop_front());
3683 const SCEV *NewMul = getMulExpr(NewOperands);
3684 return getUDivExpr(NewMul,
3685 getConstant(RHSC->getAPInt().udiv(Factor)));
3686 }
3687 }
3688 }
3689 }
3690
3691 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3692 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3693 if (auto *DivisorConstant =
3694 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3695 bool Overflow = false;
3696 APInt NewRHS =
3697 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
// If B*C overflows the type it exceeds every representable A, so the
// combined quotient is always 0.
3698 if (Overflow) {
3699 return getConstant(RHSC->getType(), 0, false);
3700 }
3701 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3702 }
3703 }
3704
3705 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3706 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3707 SmallVector<SCEVUse, 4> Operands;
3708 for (const SCEV *Op : A->operands())
3709 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3710 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3711 Operands.clear();
// Every addend must divide exactly; stop at the first one that does not.
3712 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3713 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3714 if (isa<SCEVUDivExpr>(Op) ||
3715 getMulExpr(Op, RHS) != A->getOperand(i))
3716 break;
3717 Operands.push_back(Op);
3718 }
3719 if (Operands.size() == A->getNumOperands())
3720 return getAddExpr(Operands);
3721 }
3722 }
3723
3724 // ((N - M) + (M * A)) / N --> ((N - 1) + (M * A)) / N
3725 // This is an idiom for rounding M * A up to the next multiple of N, where
3726 // M * A is already known to be a multiple of M. Instcombine can
3727 // see that some low bits of the added constant are unused, so can clear
3728 // them, but we want to canonicalise to set the low bits. This makes the
3729 // pattern easier to match, without needing to check for known bits in
3730 // A*M.
3731 const APInt &N = RHSC->getAPInt();
3732 const APInt *NMinusM, *M;
3733 const SCEV *A;
3734 if (match(LHS, m_scev_Add(m_scev_APInt(NMinusM),
3735 m_scev_Mul(m_scev_APInt(M), m_SCEV(A))))) {
3736 if (N.isPowerOf2() && M->isPowerOf2() && M->ult(N) &&
3737 *NMinusM == N - *M) {
3738 return getUDivExpr(
3740 RHS);
3741 }
3742 }
3743
3744 // Fold if both operands are constant.
3745 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3746 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3747 }
3748 }
3749
3750 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3751 const APInt *NegC, *C;
3752 if (match(LHS,
3755 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3756 return getZero(LHS->getType());
3757
3758 // (%a * %b)<nuw> / %b -> %a
3759 const auto *Mul = dyn_cast<SCEVMulExpr>(LHS);
3760 if (Mul && Mul->hasNoUnsignedWrap()) {
3761 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3762 if (Mul->getOperand(i) == RHS) {
// Rebuild the product from every operand except the one equal to RHS.
3763 SmallVector<SCEVUse, 2> Operands;
3764 append_range(Operands, Mul->operands().take_front(i));
3765 append_range(Operands, Mul->operands().drop_front(i + 1));
3766 return getMulExpr(Operands);
3767 }
3768 }
3769 }
3770
3771 // TODO: Generalize to handle any common factors.
3772 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3773 const SCEV *NewLHS, *NewRHS;
3774 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3775 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3776 return getUDivExpr(NewLHS, NewRHS);
3777
3778 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3779 // changes). Make sure we get a new one.
3780 IP = nullptr;
3781 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3782 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3783 LHS, RHS);
3784 UniqueSCEVs.InsertNode(S, IP);
3785 S->computeAndSetCanonical(*this);
3786 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3787 return S;
3788}
3789
3790APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3791 APInt A = C1->getAPInt().abs();
3792 APInt B = C2->getAPInt().abs();
3793 uint32_t ABW = A.getBitWidth();
3794 uint32_t BBW = B.getBitWidth();
3795
3796 if (ABW > BBW)
3797 B = B.zext(ABW);
3798 else if (ABW < BBW)
3799 A = A.zext(BBW);
3800
3801 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3802}
3803
3804/// Get a canonical unsigned division expression, or something simpler if
3805/// possible. There is no representation for an exact udiv in SCEV IR, but we
3806/// can attempt to optimize it prior to construction.
/// At present this simply forwards to getUDivExpr(LHS, RHS).
3808 // Currently there is no exact-specific logic.
3809
3810 return getUDivExpr(LHS, RHS);
3811}
3812
3813/// Get an add recurrence expression for the specified loop. Simplify the
3814/// expression as much as possible.
/// Builds {Start,+,Step}<L>. If Step is itself an add recurrence over the
/// same loop L, its operands are appended after Start to form a higher-order
/// recurrence {Start,+,A,+,B,...}<L>. In that case only the NW bit of Flags
/// is kept.
3816 const Loop *L,
3817 SCEV::NoWrapFlags Flags) {
3818 SmallVector<SCEVUse, 4> Operands;
3819 Operands.push_back(Start);
3820 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3821 if (StepChrec->getLoop() == L) {
// Flatten {Start,+,{A,+,B}<L>}<L> into {Start,+,A,+,B}<L>.
3822 append_range(Operands, StepChrec->operands());
3823 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3824 }
3825
3826 Operands.push_back(Step);
3827 return getAddRecExpr(Operands, L, Flags);
3828}
3829
3830/// Get an add recurrence expression for the specified loop. Simplify the
3831/// expression as much as possible.
/// A single operand is returned as-is, and a trailing zero operand is
/// dropped. Flags are strengthened where possible, and nested recurrences
/// are reordered so the more deeply nested loop's recurrence ends up
/// outermost. Otherwise a uniqued SCEVAddRecExpr is returned.
3833 const Loop *L,
3834 SCEV::NoWrapFlags Flags) {
3835 if (Operands.size() == 1) return Operands[0];
3836#ifndef NDEBUG
3837 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3838 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3839 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3840 "SCEVAddRecExpr operand types don't match!");
3841 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3842 }
3843 for (const SCEV *Op : Operands)
3845 "SCEVAddRecExpr operand is not available at loop entry!");
3846#endif
3847
// A zero last operand contributes nothing: drop it and rebuild with one
// fewer operand ({X,+,0} --> X in the two-operand case). Any caller flags
// are discarded for the rebuilt expression.
3848 if (Operands.back()->isZero()) {
3849 Operands.pop_back();
3850 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3851 }
3852
3853 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3854 // use that information to infer NUW and NSW flags. However, computing a
3855 // BE count requires calling getAddRecExpr, so we may not yet have a
3856 // meaningful BE count at this point (and if we don't, we'd be stuck
3857 // with a SCEVCouldNotCompute as the cached BE count).
3858
3859 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3860
3861 // Canonicalize nested AddRecs by nesting them in order of loop depth.
// {{A,+,B}<NestedLoop>,+,C}<L> --> {{A,+,C}<L>,+,B}<NestedLoop> when
// NestedLoop is nested more deeply inside L, or when the loops are not nested
// and L's header dominates NestedLoop's header. The rewrite is applied only
// if every operand stays invariant in its new loop.
3862 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3863 const Loop *NestedLoop = NestedAR->getLoop();
3864 if (L->contains(NestedLoop)
3865 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3866 : (!NestedLoop->contains(L) &&
3867 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3868 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3869 Operands[0] = NestedAR->getStart();
3870 // AddRecs require their operands be loop-invariant with respect to their
3871 // loops. Don't perform this transformation if it would break this
3872 // requirement.
3873 bool AllInvariant = all_of(
3874 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3875
3876 if (AllInvariant) {
3877 // Create a recurrence for the outer loop with the same step size.
3878 //
3879 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3880 // inner recurrence has the same property.
3881 SCEV::NoWrapFlags OuterFlags =
3882 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3883
3884 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3885 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3886 return isLoopInvariant(Op, NestedLoop);
3887 });
3888
3889 if (AllInvariant) {
3890 // Ok, both add recurrences are valid after the transformation.
3891 //
3892 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3893 // the outer recurrence has the same property.
3894 SCEV::NoWrapFlags InnerFlags =
3895 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3896 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3897 }
3898 }
3899 // Reset Operands to its original state.
3900 Operands[0] = NestedAR;
3901 }
3902 }
3903
3904 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3905 // already have one, otherwise create a new one.
3906 return getOrCreateAddRecExpr(Operands, L, Flags);
3907}
3908
3910 ArrayRef<SCEVUse> IndexExprs) {
// Build the SCEV for GEP from the SCEV of its pointer operand and the given
// index expressions. The GEP's IR no-wrap flags are kept only when the GEP
// is an instruction whose flags can be trusted over the SCEV's whole scope.
3911 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3912 // getSCEV(Base)->getType() has the same address space as Base->getType()
3913 // because SCEV::getType() preserves the address space.
3914 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3915 if (NW != GEPNoWrapFlags::none()) {
3916 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3917 // but to do that, we have to ensure that said flag is valid in the entire
3918 // defined scope of the SCEV.
3919 // TODO: non-instructions have global scope. We might be able to prove
3920 // some global scope cases
3921 auto *GEPI = dyn_cast<Instruction>(GEP);
3922 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3923 NW = GEPNoWrapFlags::none();
3924 }
3925
3926 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3927}
3928
3930 ArrayRef<SCEVUse> IndexExprs,
3931 Type *SrcElementTy, GEPNoWrapFlags NW) {
// Computes BaseExpr plus the sum of the byte offsets selected by IndexExprs,
// walking types starting from SrcElementTy. NW's nusw/nuw bits become
// nsw/nuw on the offset arithmetic.
3933 if (NW.hasNoUnsignedSignedWrap())
3934 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3935 if (NW.hasNoUnsignedWrap())
3936 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3937
3938 Type *CurTy = BaseExpr->getType();
3939 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3940 bool FirstIter = true;
3942 for (SCEVUse IndexExpr : IndexExprs) {
3943 // Compute the (potentially symbolic) offset in bytes for this index.
3944 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3945 // For a struct, add the member offset.
// Struct indices are assumed to be SCEVConstants (enforced by cast<>).
3946 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3947 unsigned FieldNo = Index->getZExtValue();
3948 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3949 Offsets.push_back(FieldOffset);
3950
3951 // Update CurTy to the type of the field at Index.
3952 CurTy = STy->getTypeAtIndex(Index);
3953 } else {
3954 // Update CurTy to its element type.
3955 if (FirstIter) {
3956 assert(isa<PointerType>(CurTy) &&
3957 "The first index of a GEP indexes a pointer");
3958 CurTy = SrcElementTy;
3959 FirstIter = false;
3960 } else {
3962 }
3963 // For an array, add the element offset, explicitly scaled.
3964 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3965 // Getelementptr indices are signed.
3966 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3967
3968 // Multiply the index by the element size to compute the element offset.
3969 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3970 Offsets.push_back(LocalOffset);
3971 }
3972 }
3973
3974 // Handle degenerate case of GEP without offsets.
3975 if (Offsets.empty())
3976 return BaseExpr;
3977
3978 // Add the offsets together, using the nsw/nuw flags derived from NW above.
3979 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3980 // Add the base address and the offset. We cannot use the nsw flag, as the
3981 // base address is unsigned. However, if we know that the offset is
3982 // non-negative, we can use nuw.
3983 bool NUW = NW.hasNoUnsignedWrap() ||
3986 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3987 assert(BaseExpr->getType() == GEPExpr->getType() &&
3988 "GEP should not change type mid-flight.");
3989 return GEPExpr;
3990}
3991
/// Look up an already-uniqued SCEV of kind SCEVType whose operands are
/// exactly Ops, in order. Returns it, or null if none has been created.
/// Nothing is inserted; the computed insertion position is discarded.
3992SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3995 ID.AddInteger(SCEVType);
3996 for (const SCEV *Op : Ops)
3997 ID.AddPointer(Op);
3998 void *IP = nullptr;
3999 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4000}
4001
/// Overload of the uniqued-SCEV lookup for a different operand container:
/// returns the existing SCEV of kind SCEVType with operands Ops (in order),
/// or null. Nothing is inserted.
4002SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
4005 ID.AddInteger(SCEVType);
4006 for (const SCEV *Op : Ops)
4007 ID.AddPointer(Op);
4008 void *IP = nullptr;
4009 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4010}
4011
/// Returns the absolute value of Op, expressed as smax(Op, -Op).
/// NOTE(review): the Flags used for the negation are computed from IsNSW on
/// a line not visible in this listing — presumably NSW when IsNSW is set;
/// confirm.
4012const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
4014 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
4015}
4016
// Build a (non-sequential) s/u min/max of kind Kind over Ops. The steps are:
// constant-fold and group the operands, reuse a cached node, flatten nested
// operands of the same kind, and drop operands that are duplicates or are
// ordered relative to a neighbour. Only then is a new uniqued node created.
4019 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4020 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4021 if (Ops.size() == 1) return Ops[0];
4022#ifndef NDEBUG
4023 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4024 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4025 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4026 "Operand types don't match!");
4027 assert(Ops[0]->getType()->isPointerTy() ==
4028 Ops[i]->getType()->isPointerTy() &&
4029 "min/max should be consistently pointerish");
4030 }
4031#endif
4032
4033 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4034 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4035
// The three callbacks are: fold two constants, recognize the identity
// constant (which can be dropped), and recognize the absorbing constant
// (which becomes the whole result).
4036 const SCEV *Folded = constantFoldAndGroupOps(
4037 *this, LI, DT, Ops,
4038 [&](const APInt &C1, const APInt &C2) {
4039 switch (Kind) {
4040 case scSMaxExpr:
4041 return APIntOps::smax(C1, C2);
4042 case scSMinExpr:
4043 return APIntOps::smin(C1, C2);
4044 case scUMaxExpr:
4045 return APIntOps::umax(C1, C2);
4046 case scUMinExpr:
4047 return APIntOps::umin(C1, C2);
4048 default:
4049 llvm_unreachable("Unknown SCEV min/max opcode");
4050 }
4051 },
4052 [&](const APInt &C) {
4053 // identity
4054 if (IsMax)
4055 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4056 else
4057 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4058 },
4059 [&](const APInt &C) {
4060 // absorber
4061 if (IsMax)
4062 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4063 else
4064 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4065 });
4066 if (Folded)
4067 return Folded;
4068
4069 // Check if we have created the same expression before.
4070 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4071 return S;
4072 }
4073
4074 // Find the first operation of the same kind
// (this linear skip presumes Ops is ordered by SCEV type at this point).
4075 unsigned Idx = 0;
4076 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4077 ++Idx;
4078
4079 // Check to see if one of the operands is of the same kind. If so, expand its
4080 // operands onto our operand list, and recurse to simplify.
4081 if (Idx < Ops.size()) {
4082 bool DeletedAny = false;
4083 while (Ops[Idx]->getSCEVType() == Kind) {
4084 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4085 Ops.erase(Ops.begin()+Idx);
4086 append_range(Ops, SMME->operands());
4087 DeletedAny = true;
4088 }
4089
4090 if (DeletedAny)
4091 return getMinMaxExpr(Kind, Ops);
4092 }
4093
4094 // Okay, check to see if the same value occurs in the operand list twice. If
4095 // so, delete one. Since we sorted the list, these values are required to
4096 // be adjacent.
// The same loop also drops an operand when non-recursive reasoning proves
// it is ordered relative to its neighbour, so it cannot be the result.
4101 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4102 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
// i is unsigned. After an erase at i == 0, --i wraps around and the loop's
// ++i brings it back to 0, so the same position is examined again.
4103 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4104 if (Ops[i] == Ops[i + 1] ||
4105 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4106 // X op Y op Y --> X op Y
4107 // X op Y --> X, if we know X, Y are ordered appropriately
4108 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4109 --i;
4110 --e;
4111 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4112 Ops[i + 1])) {
4113 // X op Y --> Y, if we know X, Y are ordered appropriately
4114 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4115 --i;
4116 --e;
4117 }
4118 }
4119
4120 if (Ops.size() == 1) return Ops[0];
4121
4122 assert(!Ops.empty() && "Reduced smax down to nothing!");
4123
4124 // Okay, it looks like we really DO need an expr. Check to see if we
4125 // already have one, otherwise create a new one.
4127 ID.AddInteger(Kind);
4128 for (const SCEV *Op : Ops)
4129 ID.AddPointer(Op);
4130 void *IP = nullptr;
4131 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4132 if (ExistingSCEV)
4133 return ExistingSCEV;
4134 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4136 SCEV *S = new (SCEVAllocator)
4137 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4138
4139 UniqueSCEVs.InsertNode(S, IP);
4140 S->computeAndSetCanonical(*this);
4141 registerUser(S, Ops);
4142 return S;
4143}
4144
4145namespace {
4146
/// Removes repeated operands from a sequential min/max expression, keeping
/// only the first occurrence of each.
///
/// An operand that has already been seen anywhere earlier in the traversal
/// (tracked in SeenOps) is dropped, which is reported as std::nullopt.
/// Min/max subexpressions of RootKind, or of its non-sequential equivalent,
/// are recursed into and rebuilt without their duplicates. All other nodes
/// are returned unchanged.
4147class SCEVSequentialMinMaxDeduplicatingVisitor final
4148 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4149 std::optional<const SCEV *>> {
4150 using RetVal = std::optional<const SCEV *>;
4152
4153 ScalarEvolution &SE;
4154 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4155 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4157
4158 bool canRecurseInto(SCEVTypes Kind) const {
4159 // We can only recurse into the SCEV expression of the same effective type
4160 // as the type of our root SCEV expression.
4161 return RootKind == Kind || NonSequentialRootKind == Kind;
4162 };
4163
// Deduplicate inside a min/max node of a compatible kind. Returns S itself
// if nothing changed, nullopt if every operand was dropped, and otherwise a
// freshly built expression of the same kind.
4164 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4166 "Only for min/max expressions.");
4167 SCEVTypes Kind = S->getSCEVType();
4168
4169 if (!canRecurseInto(Kind))
4170 return S;
4171
4172 auto *NAry = cast<SCEVNAryExpr>(S);
4173 SmallVector<SCEVUse> NewOps;
4174 bool Changed = visit(Kind, NAry->operands(), NewOps);
4175
4176 if (!Changed)
4177 return S;
4178 if (NewOps.empty())
4179 return std::nullopt;
4180
4182 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4183 : SE.getMinMaxExpr(Kind, NewOps);
4184 }
4185
4186 RetVal visit(const SCEV *S) {
4187 // Has the whole operand been seen already?
4188 if (!SeenOps.insert(S).second)
4189 return std::nullopt;
4190 return Base::visit(S);
4191 }
4192
4193public:
4194 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4195 SCEVTypes RootKind)
4196 : SE(SE), RootKind(RootKind),
4197 NonSequentialRootKind(
4198 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4199 RootKind)) {}
4200
// Visits OrigOps in order and returns whether anything changed. When it
// did, NewOps receives the surviving (possibly rewritten) operands. Results
// are collected in a separate vector and only moved into NewOps at the end,
// so OrigOps is not disturbed during iteration even when the caller passes
// the same vector as NewOps.
4201 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4202 SmallVectorImpl<SCEVUse> &NewOps) {
4203 bool Changed = false;
4205 Ops.reserve(OrigOps.size());
4206
4207 for (const SCEV *Op : OrigOps) {
4208 RetVal NewOp = visit(Op);
// A dropped operand (nullopt) and a rewritten one both count as a change.
4209 if (NewOp != Op)
4210 Changed = true;
4211 if (NewOp)
4212 Ops.emplace_back(*NewOp);
4213 }
4214
4215 if (Changed)
4216 NewOps = std::move(Ops);
4217 return Changed;
4218 }
4219
// Non-min/max nodes are returned unchanged; only min/max kinds are recursed
// into, via visitAnyMinMaxExpr.
4220 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4221
4222 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4223
4224 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4225
4226 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4227
4228 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4229
4230 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4231
4232 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4233
4234 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4235
4236 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4237
4238 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4239
4240 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4241
4242 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4243 return visitAnyMinMaxExpr(Expr);
4244 }
4245
4246 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4247 return visitAnyMinMaxExpr(Expr);
4248 }
4249
4250 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4251 return visitAnyMinMaxExpr(Expr);
4252 }
4253
4254 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4255 return visitAnyMinMaxExpr(Expr);
4256 }
4257
4258 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4259 return visitAnyMinMaxExpr(Expr);
4260 }
4261
4262 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4263
4264 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4265};
4266
4267} // namespace
4268
// Reports whether an expression of kind Kind is poison whenever any one of
// its operands is poison. This is true for every kind except the one that
// reaches the `return false` below. Presumably that is the sequential umin
// kind, whose later operands may be short-circuited — TODO confirm.
4270 switch (Kind) {
4271 case scConstant:
4272 case scVScale:
4273 case scTruncate:
4274 case scZeroExtend:
4275 case scSignExtend:
4276 case scPtrToAddr:
4277 case scPtrToInt:
4278 case scAddExpr:
4279 case scMulExpr:
4280 case scUDivExpr:
4281 case scAddRecExpr:
4282 case scUMaxExpr:
4283 case scSMaxExpr:
4284 case scUMinExpr:
4285 case scSMinExpr:
4286 case scUnknown:
4287 // If any operand is poison, the whole expression is poison.
4288 return true;
4290 // FIXME: if the *first* operand is poison, the whole expression is poison.
4291 return false; // Pessimistically, say that it does not propagate poison.
4292 case scCouldNotCompute:
4293 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4294 }
4295 llvm_unreachable("Unknown SCEV kind!");
4296}
4297
4298namespace {
4299// The only way poison may be introduced in a SCEV expression is from a
4300// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4301// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4302// introduce poison -- they encode guaranteed, non-speculated knowledge.
4303//
4304// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4305// with the notable exception of umin_seq, where only poison from the first
4306// operand is (unconditionally) propagated.
//
// SCEVPoisonCollector is a callback for visitAll(). It records every
// reachable SCEVUnknown whose IR value is not guaranteed to be non-poison.
4307struct SCEVPoisonCollector {
// If false, follow() refuses to descend into nodes that may block poison
// from their operands. NOTE(review): the exact test is on a line not visible
// here — presumably based on whether the node's kind unconditionally
// propagates operand poison; confirm.
4308 bool LookThroughMaybePoisonBlocking;
// SCEVUnknowns found so far that might be poison.
4309 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4310 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4311 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4312
// Returning false stops the traversal from visiting S's operands.
4313 bool follow(const SCEV *S) {
4314 if (!LookThroughMaybePoisonBlocking &&
4316 return false;
4317
4318 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4319 if (!isGuaranteedNotToBePoison(SU->getValue()))
4320 MaybePoison.insert(SU);
4321 }
4322 return true;
4323 }
// The walk never terminates early; the whole reachable expression is seen.
4324 bool isDone() const { return false; }
4325};
4326} // namespace
4327
4328/// Return true if V is poison given that AssumedPoison is already poison.
4329static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4330 // First collect all SCEVs that might result in AssumedPoison to be poison.
4331 // We need to look through potentially poison-blocking operations here,
4332 // because we want to find all SCEVs that *might* result in poison, not only
4333 // those that are *required* to.
4334 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4335 visitAll(AssumedPoison, PC1);
4336
4337 // AssumedPoison is never poison. As the assumption is false, the implication
4338 // is true. Don't bother walking the other SCEV in this case.
4339 if (PC1.MaybePoison.empty())
4340 return true;
4341
4342 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4343 // as well. We cannot look through potentially poison-blocking operations
4344 // here, as their arguments only *may* make the result poison.
4345 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4346 visitAll(S, PC2);
4347
4348 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4349 // it will also make S poison by being part of PC2.MaybePoison.
4350 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4351}
4352
4354 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
// Adds to Result the IR values of the SCEVUnknowns in S that, if poison,
// necessarily make S poison. Potentially poison-blocking nodes are not
// looked through.
4355 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4356 visitAll(S, PC);
4357 for (const SCEVUnknown *SU : PC.MaybePoison)
4358 Result.insert(SU->getValue());
4359}
4360
4362 const SCEV *S, Instruction *I,
4363 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
// Decides whether instruction I can stand in for the value of S without
// being poison in cases where S is not. On success, instructions whose
// poison-generating flags/metadata must be dropped are appended to
// DropPoisonGeneratingInsts. The walk gives up (returns false) once more
// than 16 values have been visited.
4364 // If the instruction cannot be poison, it's always safe to reuse.
4366 return true;
4367
4368 // Otherwise, it is possible that I is more poisonous than S. Collect the
4369 // poison-contributors of S, and then check whether I has any additional
4370 // poison-contributors. Poison that is contributed through poison-generating
4371 // flags is handled by dropping those flags instead.
4373 getPoisonGeneratingValues(PoisonVals, S);
4374
4375 SmallVector<Value *> Worklist;
4377 Worklist.push_back(I);
4378 while (!Worklist.empty()) {
4379 Value *V = Worklist.pop_back_val();
4380 if (!Visited.insert(V).second)
4381 continue;
4382
4383 // Avoid walking large instruction graphs.
4384 if (Visited.size() > 16)
4385 return false;
4386
4387 // Either the value can't be poison, or S would also be poison if it
4388 // is.
4389 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4390 continue;
4391
// Note: this local I shadows the parameter I for the rest of the loop body.
4392 auto *I = dyn_cast<Instruction>(V);
4393 if (!I)
4394 return false;
4395
4396 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4397 // can't replace an arbitrary add with disjoint or, even if we drop the
4398 // flag. We would need to convert the or into an add.
4399 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4400 if (PDI->isDisjoint())
4401 return false;
4402
4403 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4404 // because SCEV currently assumes it can't be poison. Remove this special
4405 // case once we properly model when vscale can be poison.
4406 if (auto *II = dyn_cast<IntrinsicInst>(I);
4407 II && II->getIntrinsicID() == Intrinsic::vscale)
4408 continue;
4409
4410 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4411 return false;
4412
4413 // If the instruction can't create poison, we can recurse to its operands.
4414 if (I->hasPoisonGeneratingAnnotations())
4415 DropPoisonGeneratingInsts.push_back(I);
4416
4417 llvm::append_range(Worklist, I->operands());
4418 }
4419 return true;
4420}
4421
/// Get (or create) a sequential min/max expression of kind Kind over Ops.
/// Because the operation is not commutative, operands are never reordered.
/// Before a new node is created, the following simplifications are tried:
///  - drop repeated operands, keeping the first occurrence;
///  - flatten nested expressions of the same kind in place;
///  - relax an adjacent pair to the non-sequential min/max when that is safe;
///  - drop an operand that its predecessor is known to bound.
4422const SCEV *
4425 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4426 "Not a SCEVSequentialMinMaxExpr!");
4427 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4428 if (Ops.size() == 1)
4429 return Ops[0];
4430#ifndef NDEBUG
4431 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4432 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4433 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4434 "Operand types don't match!");
4435 assert(Ops[0]->getType()->isPointerTy() ==
4436 Ops[i]->getType()->isPointerTy() &&
4437 "min/max should be consistently pointerish");
4438 }
4439#endif
4440
4441 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4442 // so we can *NOT* do any kind of sorting of the expressions!
4443
4444 // Check if we have created the same expression before.
4445 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4446 return S;
4447
4448 // FIXME: there are *some* simplifications that we can do here.
4449
4450 // Keep only the first instance of an operand.
4451 {
4452 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4453 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4454 if (Changed)
4455 return getSequentialMinMaxExpr(Kind, Ops);
4456 }
4457
4458 // Check to see if one of the operands is of the same kind. If so, expand its
4459 // operands onto our operand list, and recurse to simplify.
// Unlike the commutative case, the nested operands are spliced in at the
// same position (not appended) so operand order is preserved.
4460 {
4461 unsigned Idx = 0;
4462 bool DeletedAny = false;
4463 while (Idx < Ops.size()) {
4464 if (Ops[Idx]->getSCEVType() != Kind) {
4465 ++Idx;
4466 continue;
4467 }
4468 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4469 Ops.erase(Ops.begin() + Idx);
4470 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4471 SMME->operands().end());
4472 DeletedAny = true;
4473 }
4474
4475 if (DeletedAny)
4476 return getSequentialMinMaxExpr(Kind, Ops);
4477 }
4478
4479 const SCEV *SaturationPoint;
4481 switch (Kind) {
4483 SaturationPoint = getZero(Ops[0]->getType());
4484 Pred = ICmpInst::ICMP_ULE;
4485 break;
4486 default:
4487 llvm_unreachable("Not a sequential min/max type.");
4488 }
4489
4490 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4491 if (!isGuaranteedNotToCauseUB(Ops[i]))
4492 continue;
4493 // We can replace %x umin_seq %y with %x umin %y if either:
4494 // * %y being poison implies %x is also poison.
4495 // * %x cannot be the saturating value (e.g. zero for umin).
4496 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4497 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4498 SaturationPoint)) {
4499 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4500 Ops[i - 1] = getMinMaxExpr(
4502 SeqOps);
4503 Ops.erase(Ops.begin() + i);
4504 return getSequentialMinMaxExpr(Kind, Ops);
4505 }
4506 // Fold %x umin_seq %y to %x if %x ule %y.
4507 // TODO: We might be able to prove the predicate for a later operand.
4508 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4509 Ops.erase(Ops.begin() + i);
4510 return getSequentialMinMaxExpr(Kind, Ops);
4511 }
4512 }
4513
4514 // Okay, it looks like we really DO need an expr. Check to see if we
4515 // already have one, otherwise create a new one.
4517 ID.AddInteger(Kind);
4518 for (const SCEV *Op : Ops)
4519 ID.AddPointer(Op);
4520 void *IP = nullptr;
4521 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4522 if (ExistingSCEV)
4523 return ExistingSCEV;
4524
4525 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4527 SCEV *S = new (SCEVAllocator)
4528 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4529
4530 UniqueSCEVs.InsertNode(S, IP);
4531 S->computeAndSetCanonical(*this);
4532 registerUser(S, Ops);
4533 return S;
4534}
4535
4540
4544
4549
4553
4558
4562
4564 bool Sequential) {
4565 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4566 return getUMinExpr(Ops, Sequential);
4567}
4568
4574
4575const SCEV *
4577 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4578 if (Size.isScalable())
4579 Res = getMulExpr(Res, getVScale(IntTy));
4580 return Res;
4581}
4582
4584 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4585}
4586
4588 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4589}
4590
4592 StructType *STy,
4593 unsigned FieldNo) {
4594 // We can bypass creating a target-independent constant expression and then
4595 // folding it back into a ConstantInt. This is just a compile-time
4596 // optimization.
4597 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4598 assert(!SL->getSizeInBits().isScalable() &&
4599 "Cannot get offset for structure containing scalable vector types");
4600 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4601}
4602
4604 // Don't attempt to do anything other than create a SCEVUnknown object
4605 // here. createSCEV only calls getUnknown after checking for all other
4606 // interesting possibilities, and any other code that calls getUnknown
4607 // is doing so in order to hide a value from SCEV canonicalization.
4608
4610 ID.AddInteger(scUnknown);
4611 ID.AddPointer(V);
4612 void *IP = nullptr;
4613 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4614 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4615 "Stale SCEVUnknown in uniquing map!");
4616 return S;
4617 }
4618 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4619 FirstUnknown);
4620 FirstUnknown = cast<SCEVUnknown>(S);
4621 UniqueSCEVs.InsertNode(S, IP);
4622 S->computeAndSetCanonical(*this);
4623 return S;
4624}
4625
4626//===----------------------------------------------------------------------===//
4627// Basic SCEV Analysis and PHI Idiom Recognition Code
4628//
4629
4630/// Test if values of the given type are analyzable within the SCEV
4631/// framework. This primarily includes integer types, and it can optionally
4632/// include pointer types if the ScalarEvolution class has access to
4633/// target-specific information.
4635 // Integers and pointers are always SCEVable.
4636 return Ty->isIntOrPtrTy();
4637}
4638
4639/// Return the size in bits of the specified type, for which isSCEVable must
4640/// return true.
4642 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4643 if (Ty->isPointerTy())
4645 return getDataLayout().getTypeSizeInBits(Ty);
4646}
4647
4648/// Return a type with the same bitwidth as the given type and which represents
4649/// how SCEV will treat the given type, for which isSCEVable must return
4650/// true. For pointer types, this is the pointer index sized integer type.
4652 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4653
4654 if (Ty->isIntegerTy())
4655 return Ty;
4656
4657 // The only other support type is pointer.
4658 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4659 return getDataLayout().getIndexType(Ty);
4660}
4661
4663 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4664}
4665
4667 const SCEV *B) {
4668 /// For a valid use point to exist, the defining scope of one operand
4669 /// must dominate the other.
4670 bool PreciseA, PreciseB;
4671 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4672 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4673 if (!PreciseA || !PreciseB)
4674 // Can't tell.
4675 return false;
4676 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4677 DT.dominates(ScopeB, ScopeA);
4678}
4679
4681 return CouldNotCompute.get();
4682}
4683
4684bool ScalarEvolution::checkValidity(const SCEV *S) const {
4685 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4686 auto *SU = dyn_cast<SCEVUnknown>(S);
4687 return SU && SU->getValue() == nullptr;
4688 });
4689
4690 return !ContainsNulls;
4691}
4692
4694 HasRecMapType::iterator I = HasRecMap.find(S);
4695 if (I != HasRecMap.end())
4696 return I->second;
4697
4698 bool FoundAddRec =
4699 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4700 HasRecMap.insert({S, FoundAddRec});
4701 return FoundAddRec;
4702}
4703
4704/// Return the ValueOffsetPair set for \p S. \p S can be represented
4705/// by the value and offset from any ValueOffsetPair in the set.
4706ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4707 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4708 if (SI == ExprValueMap.end())
4709 return {};
4710 return SI->second.getArrayRef();
4711}
4712
4713/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4714/// cannot be used separately. eraseValueFromMap should be used to remove
4715/// V from ValueExprMap and ExprValueMap at the same time.
4716void ScalarEvolution::eraseValueFromMap(Value *V) {
4717 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4718 if (I != ValueExprMap.end()) {
4719 auto EVIt = ExprValueMap.find(I->second);
4720 bool Removed = EVIt->second.remove(V);
4721 (void) Removed;
4722 assert(Removed && "Value not in ExprValueMap?");
4723 ValueExprMap.erase(I);
4724 }
4725}
4726
4727void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4728 // A recursive query may have already computed the SCEV. It should be
4729 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4730 // inferred nowrap flags.
4731 auto It = ValueExprMap.find_as(V);
4732 if (It == ValueExprMap.end()) {
4733 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4734 ExprValueMap[S].insert(V);
4735 }
4736}
4737
4738/// Return an existing SCEV if it exists, otherwise analyze the expression and
4739/// create a new one.
4741 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4742
4743 if (const SCEV *S = getExistingSCEV(V))
4744 return S;
4745 return createSCEVIter(V);
4746}
4747
4749 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4750
4751 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4752 if (I != ValueExprMap.end()) {
4753 const SCEV *S = I->second;
4754 assert(checkValidity(S) &&
4755 "existing SCEV has not been properly invalidated");
4756 return S;
4757 }
4758 return nullptr;
4759}
4760
4761/// Return a SCEV corresponding to -V = -1*V
4763 SCEV::NoWrapFlags Flags) {
4764 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4765 return getConstant(
4766 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4767
4768 Type *Ty = V->getType();
4769 Ty = getEffectiveSCEVType(Ty);
4770 return getMulExpr(V, getMinusOne(Ty), Flags);
4771}
4772
4773/// If Expr computes ~A, return A else return nullptr
4774static const SCEV *MatchNotExpr(const SCEV *Expr) {
4775 const SCEV *MulOp;
4776 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4777 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4778 return MulOp;
4779 return nullptr;
4780}
4781
4782/// Return a SCEV corresponding to ~V = -1-V
4784 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4785
4786 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4787 return getConstant(
4788 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4789
4790 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4791 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4792 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4793 SmallVector<SCEVUse, 2> MatchedOperands;
4794 for (const SCEV *Operand : MME->operands()) {
4795 const SCEV *Matched = MatchNotExpr(Operand);
4796 if (!Matched)
4797 return (const SCEV *)nullptr;
4798 MatchedOperands.push_back(Matched);
4799 }
4800 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4801 MatchedOperands);
4802 };
4803 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4804 return Replaced;
4805 }
4806
4807 Type *Ty = V->getType();
4808 Ty = getEffectiveSCEVType(Ty);
4809 return getMinusSCEV(getMinusOne(Ty), V);
4810}
4811
4813 assert(P->getType()->isPointerTy());
4814
4815 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4816 // The base of an AddRec is the first operand.
4817 SmallVector<SCEVUse> Ops{AddRec->operands()};
4818 Ops[0] = removePointerBase(Ops[0]);
4819 // Don't try to transfer nowrap flags for now. We could in some cases
4820 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4821 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4822 }
4823 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4824 // The base of an Add is the pointer operand.
4825 SmallVector<SCEVUse> Ops{Add->operands()};
4826 SCEVUse *PtrOp = nullptr;
4827 for (SCEVUse &AddOp : Ops) {
4828 if (AddOp->getType()->isPointerTy()) {
4829 assert(!PtrOp && "Cannot have multiple pointer ops");
4830 PtrOp = &AddOp;
4831 }
4832 }
4833 *PtrOp = removePointerBase(*PtrOp);
4834 // Don't try to transfer nowrap flags for now. We could in some cases
4835 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4836 return getAddExpr(Ops);
4837 }
4838 // Any other expression must be a pointer base.
4839 return getZero(P->getType());
4840}
4841
4843 SCEV::NoWrapFlags Flags,
4844 unsigned Depth) {
4845 // Fast path: X - X --> 0.
4846 if (LHS == RHS)
4847 return getZero(LHS->getType());
4848
4849 // If we subtract two pointers with different pointer bases, bail.
4850 // Eventually, we're going to add an assertion to getMulExpr that we
4851 // can't multiply by a pointer.
4852 if (RHS->getType()->isPointerTy()) {
4853 if (!LHS->getType()->isPointerTy() ||
4854 getPointerBase(LHS) != getPointerBase(RHS))
4855 return getCouldNotCompute();
4856 LHS = removePointerBase(LHS);
4857 RHS = removePointerBase(RHS);
4858 }
4859
4860 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4861 // makes it so that we cannot make much use of NUW.
4862 auto AddFlags = SCEV::FlagAnyWrap;
4863 const bool RHSIsNotMinSigned =
4865 if (hasFlags(Flags, SCEV::FlagNSW)) {
4866 // Let M be the minimum representable signed value. Then (-1)*RHS
4867 // signed-wraps if and only if RHS is M. That can happen even for
4868 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4869 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4870 // (-1)*RHS, we need to prove that RHS != M.
4871 //
4872 // If LHS is non-negative and we know that LHS - RHS does not
4873 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4874 // either by proving that RHS > M or that LHS >= 0.
4875 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4876 AddFlags = SCEV::FlagNSW;
4877 }
4878 }
4879
4880 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4881 // RHS is NSW and LHS >= 0.
4882 //
4883 // The difficulty here is that the NSW flag may have been proven
4884 // relative to a loop that is to be found in a recurrence in LHS and
4885 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4886 // larger scope than intended.
4887 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4888
4889 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4890}
4891
4893 unsigned Depth) {
4894 Type *SrcTy = V->getType();
4895 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4896 "Cannot truncate or zero extend with non-integer arguments!");
4897 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4898 return V; // No conversion
4899 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4900 return getTruncateExpr(V, Ty, Depth);
4901 return getZeroExtendExpr(V, Ty, Depth);
4902}
4903
4905 unsigned Depth) {
4906 Type *SrcTy = V->getType();
4907 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4908 "Cannot truncate or zero extend with non-integer arguments!");
4909 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4910 return V; // No conversion
4911 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4912 return getTruncateExpr(V, Ty, Depth);
4913 return getSignExtendExpr(V, Ty, Depth);
4914}
4915
4916const SCEV *
4918 Type *SrcTy = V->getType();
4919 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4920 "Cannot noop or zero extend with non-integer arguments!");
4922 "getNoopOrZeroExtend cannot truncate!");
4923 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4924 return V; // No conversion
4925 return getZeroExtendExpr(V, Ty);
4926}
4927
4928const SCEV *
4930 Type *SrcTy = V->getType();
4931 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4932 "Cannot noop or sign extend with non-integer arguments!");
4934 "getNoopOrSignExtend cannot truncate!");
4935 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4936 return V; // No conversion
4937 return getSignExtendExpr(V, Ty);
4938}
4939
4940const SCEV *
4942 Type *SrcTy = V->getType();
4943 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4944 "Cannot noop or any extend with non-integer arguments!");
4946 "getNoopOrAnyExtend cannot truncate!");
4947 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4948 return V; // No conversion
4949 return getAnyExtendExpr(V, Ty);
4950}
4951
4952const SCEV *
4954 Type *SrcTy = V->getType();
4955 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4956 "Cannot truncate or noop with non-integer arguments!");
4958 "getTruncateOrNoop cannot extend!");
4959 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4960 return V; // No conversion
4961 return getTruncateExpr(V, Ty);
4962}
4963
4965 const SCEV *RHS) {
4966 const SCEV *PromotedLHS = LHS;
4967 const SCEV *PromotedRHS = RHS;
4968
4969 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4970 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4971 else
4972 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4973
4974 return getUMaxExpr(PromotedLHS, PromotedRHS);
4975}
4976
4978 const SCEV *RHS,
4979 bool Sequential) {
4980 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4981 return getUMinFromMismatchedTypes(Ops, Sequential);
4982}
4983
4984const SCEV *
4986 bool Sequential) {
4987 assert(!Ops.empty() && "At least one operand must be!");
4988 // Trivial case.
4989 if (Ops.size() == 1)
4990 return Ops[0];
4991
4992 // Find the max type first.
4993 Type *MaxType = nullptr;
4994 for (SCEVUse S : Ops)
4995 if (MaxType)
4996 MaxType = getWiderType(MaxType, S->getType());
4997 else
4998 MaxType = S->getType();
4999 assert(MaxType && "Failed to find maximum type!");
5000
5001 // Extend all ops to max type.
5002 SmallVector<SCEVUse, 2> PromotedOps;
5003 for (SCEVUse S : Ops)
5004 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
5005
5006 // Generate umin.
5007 return getUMinExpr(PromotedOps, Sequential);
5008}
5009
5011 // A pointer operand may evaluate to a nonpointer expression, such as null.
5012 if (!V->getType()->isPointerTy())
5013 return V;
5014
5015 while (true) {
5016 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5017 V = AddRec->getStart();
5018 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5019 const SCEV *PtrOp = nullptr;
5020 for (const SCEV *AddOp : Add->operands()) {
5021 if (AddOp->getType()->isPointerTy()) {
5022 assert(!PtrOp && "Cannot have multiple pointer ops");
5023 PtrOp = AddOp;
5024 }
5025 }
5026 assert(PtrOp && "Must have pointer op");
5027 V = PtrOp;
5028 } else // Not something we can look further into.
5029 return V;
5030 }
5031}
5032
5033/// Push users of the given Instruction onto the given Worklist.
5037 // Push the def-use children onto the Worklist stack.
5038 for (User *U : I->users()) {
5039 auto *UserInsn = cast<Instruction>(U);
5040 if (Visited.insert(UserInsn).second)
5041 Worklist.push_back(UserInsn);
5042 }
5043}
5044
5045namespace {
5046
5047/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5048/// expression in case its Loop is L. If it is not L then
5049/// if IgnoreOtherLoops is true then use AddRec itself
5050/// otherwise rewrite cannot be done.
5051/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5052class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5053public:
5054 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5055 bool IgnoreOtherLoops = true) {
5056 SCEVInitRewriter Rewriter(L, SE);
5057 const SCEV *Result = Rewriter.visit(S);
5058 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5059 return SE.getCouldNotCompute();
5060 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5061 ? SE.getCouldNotCompute()
5062 : Result;
5063 }
5064
5065 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5066 if (!SE.isLoopInvariant(Expr, L))
5067 SeenLoopVariantSCEVUnknown = true;
5068 return Expr;
5069 }
5070
5071 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5072 // Only re-write AddRecExprs for this loop.
5073 if (Expr->getLoop() == L)
5074 return Expr->getStart();
5075 SeenOtherLoops = true;
5076 return Expr;
5077 }
5078
5079 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5080
5081 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5082
5083private:
5084 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5085 : SCEVRewriteVisitor(SE), L(L) {}
5086
5087 const Loop *L;
5088 bool SeenLoopVariantSCEVUnknown = false;
5089 bool SeenOtherLoops = false;
5090};
5091
5092/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5093/// increment expression in case its Loop is L. If it is not L then
5094/// use AddRec itself.
5095/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5096class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5097public:
5098 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5099 SCEVPostIncRewriter Rewriter(L, SE);
5100 const SCEV *Result = Rewriter.visit(S);
5101 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5102 ? SE.getCouldNotCompute()
5103 : Result;
5104 }
5105
5106 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5107 if (!SE.isLoopInvariant(Expr, L))
5108 SeenLoopVariantSCEVUnknown = true;
5109 return Expr;
5110 }
5111
5112 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5113 // Only re-write AddRecExprs for this loop.
5114 if (Expr->getLoop() == L)
5115 return Expr->getPostIncExpr(SE);
5116 SeenOtherLoops = true;
5117 return Expr;
5118 }
5119
5120 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5121
5122 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5123
5124private:
5125 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5126 : SCEVRewriteVisitor(SE), L(L) {}
5127
5128 const Loop *L;
5129 bool SeenLoopVariantSCEVUnknown = false;
5130 bool SeenOtherLoops = false;
5131};
5132
5133/// This class evaluates the compare condition by matching it against the
5134/// condition of loop latch. If there is a match we assume a true value
5135/// for the condition while building SCEV nodes.
5136class SCEVBackedgeConditionFolder
5137 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5138public:
5139 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5140 ScalarEvolution &SE) {
5141 bool IsPosBECond = false;
5142 Value *BECond = nullptr;
5143 if (BasicBlock *Latch = L->getLoopLatch()) {
5144 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5145 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5146 "Both outgoing branches should not target same header!");
5147 BECond = BI->getCondition();
5148 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5149 } else {
5150 return S;
5151 }
5152 }
5153 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5154 return Rewriter.visit(S);
5155 }
5156
5157 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5158 const SCEV *Result = Expr;
5159 bool InvariantF = SE.isLoopInvariant(Expr, L);
5160
5161 if (!InvariantF) {
5163 switch (I->getOpcode()) {
5164 case Instruction::Select: {
5165 SelectInst *SI = cast<SelectInst>(I);
5166 std::optional<const SCEV *> Res =
5167 compareWithBackedgeCondition(SI->getCondition());
5168 if (Res) {
5169 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5170 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5171 }
5172 break;
5173 }
5174 default: {
5175 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5176 if (Res)
5177 Result = *Res;
5178 break;
5179 }
5180 }
5181 }
5182 return Result;
5183 }
5184
5185private:
5186 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5187 bool IsPosBECond, ScalarEvolution &SE)
5188 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5189 IsPositiveBECond(IsPosBECond) {}
5190
5191 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5192
5193 const Loop *L;
5194 /// Loop back condition.
5195 Value *BackedgeCond = nullptr;
5196 /// Set to true if loop back is on positive branch condition.
5197 bool IsPositiveBECond;
5198};
5199
5200std::optional<const SCEV *>
5201SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5202
5203 // If value matches the backedge condition for loop latch,
5204 // then return a constant evolution node based on loopback
5205 // branch taken.
5206 if (BackedgeCond == IC)
5207 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5209 return std::nullopt;
5210}
5211
5212class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5213public:
5214 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5215 ScalarEvolution &SE) {
5216 SCEVShiftRewriter Rewriter(L, SE);
5217 const SCEV *Result = Rewriter.visit(S);
5218 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5219 }
5220
5221 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5222 // Only allow AddRecExprs for this loop.
5223 if (!SE.isLoopInvariant(Expr, L))
5224 Valid = false;
5225 return Expr;
5226 }
5227
5228 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5229 if (Expr->getLoop() == L && Expr->isAffine())
5230 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5231 Valid = false;
5232 return Expr;
5233 }
5234
5235 bool isValid() { return Valid; }
5236
5237private:
5238 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5239 : SCEVRewriteVisitor(SE), L(L) {}
5240
5241 const Loop *L;
5242 bool Valid = true;
5243};
5244
5245} // end anonymous namespace
5246
5247void ScalarEvolution::inferNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5248 if (!AR->isAffine())
5249 return;
5250
5251 // Force computation of ranges, which will also perform range-based flag
5252 // inference.
5253 if (!AR->hasNoSignedWrap())
5254 (void)getSignedRange(AR);
5255
5256 if (!AR->hasNoUnsignedWrap())
5257 (void)getUnsignedRange(AR);
5258
5259 if (!AR->hasNoSelfWrap()) {
5260 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5261 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5262 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5263 const APInt &BECountAP = BECountMax->getAPInt();
5264 unsigned NoOverflowBitWidth =
5265 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5266 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5267 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
5268 }
5269 }
5270}
5271
5273ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5275
5276 if (AR->hasNoSignedWrap())
5277 return Result;
5278
5279 if (!AR->isAffine())
5280 return Result;
5281
5282 // This function can be expensive, only try to prove NSW once per AddRec.
5283 if (!SignedWrapViaInductionTried.insert(AR).second)
5284 return Result;
5285
5286 const SCEV *Step = AR->getStepRecurrence(*this);
5287 const Loop *L = AR->getLoop();
5288
5289 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5290 // Note that this serves two purposes: It filters out loops that are
5291 // simply not analyzable, and it covers the case where this code is
5292 // being called from within backedge-taken count analysis, such that
5293 // attempting to ask for the backedge-taken count would likely result
5294 // in infinite recursion. In the later case, the analysis code will
5295 // cope with a conservative value, and it will take care to purge
5296 // that value once it has finished.
5297 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5298
5299 // Normally, in the cases we can prove no-overflow via a
5300 // backedge guarding condition, we can also compute a backedge
5301 // taken count for the loop. The exceptions are assumptions and
5302 // guards present in the loop -- SCEV is not great at exploiting
5303 // these to compute max backedge taken counts, but can still use
5304 // these to prove lack of overflow. Use this fact to avoid
5305 // doing extra work that may not pay off.
5306
5307 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5308 AC.assumptions().empty())
5309 return Result;
5310
5311 // If the backedge is guarded by a comparison with the pre-inc value the
5312 // addrec is safe. Also, if the entry is guarded by a comparison with the
5313 // start value and the backedge is guarded by a comparison with the post-inc
5314 // value, the addrec is safe.
5316 const SCEV *OverflowLimit =
5317 getSignedOverflowLimitForStep(Step, &Pred, this);
5318 if (OverflowLimit &&
5319 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5320 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5321 Result = setFlags(Result, SCEV::FlagNSW);
5322 }
5323 return Result;
5324}
5326ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5328
5329 if (AR->hasNoUnsignedWrap())
5330 return Result;
5331
5332 if (!AR->isAffine())
5333 return Result;
5334
5335 // This function can be expensive, only try to prove NUW once per AddRec.
5336 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5337 return Result;
5338
5339 const SCEV *Step = AR->getStepRecurrence(*this);
5340 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5341 const Loop *L = AR->getLoop();
5342
5343 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5344 // Note that this serves two purposes: It filters out loops that are
5345 // simply not analyzable, and it covers the case where this code is
5346 // being called from within backedge-taken count analysis, such that
5347 // attempting to ask for the backedge-taken count would likely result
5348 // in infinite recursion. In the later case, the analysis code will
5349 // cope with a conservative value, and it will take care to purge
5350 // that value once it has finished.
5351 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5352
5353 // Normally, in the cases we can prove no-overflow via a
5354 // backedge guarding condition, we can also compute a backedge
5355 // taken count for the loop. The exceptions are assumptions and
5356 // guards present in the loop -- SCEV is not great at exploiting
5357 // these to compute max backedge taken counts, but can still use
5358 // these to prove lack of overflow. Use this fact to avoid
5359 // doing extra work that may not pay off.
5360
5361 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5362 AC.assumptions().empty())
5363 return Result;
5364
5365 // If the backedge is guarded by a comparison with the pre-inc value the
5366 // addrec is safe. Also, if the entry is guarded by a comparison with the
5367 // start value and the backedge is guarded by a comparison with the post-inc
5368 // value, the addrec is safe.
5369 if (isKnownPositive(Step)) {
5370 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5371 getUnsignedRangeMax(Step));
5374 Result = setFlags(Result, SCEV::FlagNUW);
5375 }
5376 }
5377
5378 return Result;
5379}
5380
5381namespace {
5382
5383/// Represents an abstract binary operation. This may exist as a
5384/// normal instruction or constant expression, or may have been
5385/// derived from an expression tree.
5386struct BinaryOp {
5387 unsigned Opcode;
5388 Value *LHS;
5389 Value *RHS;
5390 bool IsNSW = false;
5391 bool IsNUW = false;
5392
5393 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5394 /// constant expression.
5395 Operator *Op = nullptr;
5396
5397 explicit BinaryOp(Operator *Op)
5398 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5399 Op(Op) {
5400 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5401 IsNSW = OBO->hasNoSignedWrap();
5402 IsNUW = OBO->hasNoUnsignedWrap();
5403 }
5404 }
5405
5406 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5407 bool IsNUW = false)
5408 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5409};
5410
5411} // end anonymous namespace
5412
5413/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5414static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5415 AssumptionCache &AC,
5416 const DominatorTree &DT,
5417 const Instruction *CxtI) {
5418 auto *Op = dyn_cast<Operator>(V);
5419 if (!Op)
5420 return std::nullopt;
5421
5422 // Implementation detail: all the cleverness here should happen without
5423 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5424 // SCEV expressions when possible, and we should not break that.
5425
5426 switch (Op->getOpcode()) {
5427 case Instruction::Add:
5428 case Instruction::Sub:
5429 case Instruction::Mul:
5430 case Instruction::UDiv:
5431 case Instruction::URem:
5432 case Instruction::And:
5433 case Instruction::AShr:
5434 case Instruction::Shl:
5435 return BinaryOp(Op);
5436
5437 case Instruction::Or: {
5438 // Convert or disjoint into add nuw nsw.
5439 if (cast<PossiblyDisjointInst>(Op)->isDisjoint()) {
5440 BinaryOp BinOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5441 /*IsNSW=*/true, /*IsNUW=*/true);
5442 // Keep the reference to the original instruction so that we can later
5443 // check whether it can produce poison value or not.
5444 BinOp.Op = Op;
5445 return BinOp;
5446 }
5447 return BinaryOp(Op);
5448 }
5449
5450 case Instruction::Xor:
5451 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5452 // If the RHS of the xor is a signmask, then this is just an add.
5453 // Instcombine turns add of signmask into xor as a strength reduction step.
5454 if (RHSC->getValue().isSignMask())
5455 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5456 // Binary `xor` is a bit-wise `add`.
5457 if (V->getType()->isIntegerTy(1))
5458 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5459 return BinaryOp(Op);
5460
5461 case Instruction::LShr:
5462 // Turn logical shift right of a constant into a unsigned divide.
5463 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5464 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5465
5466 // If the shift count is not less than the bitwidth, the result of
5467 // the shift is undefined. Don't try to analyze it, because the
5468 // resolution chosen here may differ from the resolution chosen in
5469 // other parts of the compiler.
5470 if (SA->getValue().ult(BitWidth)) {
5471 Constant *X =
5472 ConstantInt::get(SA->getContext(),
5473 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5474 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5475 }
5476 }
5477 return BinaryOp(Op);
5478
5479 case Instruction::ExtractValue: {
5480 auto *EVI = cast<ExtractValueInst>(Op);
5481 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5482 break;
5483
5484 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5485 if (!WO)
5486 break;
5487
5488 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5489 bool Signed = WO->isSigned();
5490 // TODO: Should add nuw/nsw flags for mul as well.
5491 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5492 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5493
5494 // Now that we know that all uses of the arithmetic-result component of
5495 // CI are guarded by the overflow check, we can go ahead and pretend
5496 // that the arithmetic is non-overflowing.
5497 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5498 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5499 }
5500
5501 default:
5502 break;
5503 }
5504
5505 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5506 // semantics as a Sub, return a binary sub expression.
5507 if (auto *II = dyn_cast<IntrinsicInst>(V))
5508 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5509 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5510
5511 return std::nullopt;
5512}
5513
5514/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5515/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5516/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5517/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5518/// follows one of the following patterns:
5519/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5520/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5521/// If the SCEV expression of \p Op conforms with one of the expected patterns
5522/// we return the type of the truncation operation, and indicate whether the
5523/// truncated type should be treated as signed/unsigned by setting
5524/// \p Signed to true/false, respectively.
5525static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5526 bool &Signed, ScalarEvolution &SE) {
5527 // The case where Op == SymbolicPHI (that is, with no type conversions on
5528 // the way) is handled by the regular add recurrence creating logic and
5529 // would have already been triggered in createAddRecForPHI. Reaching it here
5530 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5531 // because one of the other operands of the SCEVAddExpr updating this PHI is
5532 // not invariant).
5533 //
5534 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5535 // this case predicates that allow us to prove that Op == SymbolicPHI will
5536 // be added.
5537 if (Op == SymbolicPHI)
5538 return nullptr;
5539
5540 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5541 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5542 if (SourceBits != NewBits)
5543 return nullptr;
5544
5545 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5546 Signed = true;
5547 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5548 }
5549 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5550 Signed = false;
5551 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5552 }
5553 return nullptr;
5554}
5555
5556static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5557 if (!PN->getType()->isIntegerTy())
5558 return nullptr;
5559 const Loop *L = LI.getLoopFor(PN->getParent());
5560 if (!L || L->getHeader() != PN->getParent())
5561 return nullptr;
5562 return L;
5563}
5564
5565// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5566// computation that updates the phi follows the following pattern:
5567// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5569// If so, try to see if it can be rewritten as an AddRecExpr under some
5570// Predicates. If successful, return them as a pair. Also cache the results
5571// of the analysis.
5572//
5573// Example usage scenario:
5574// Say the Rewriter is called for the following SCEV:
5575// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5576// where:
5577// %X = phi i64 (%Start, %BEValue)
5578// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5579// and call this function with %SymbolicPHI = %X.
5580//
5581// The analysis will find that the value coming around the backedge has
5582// the following SCEV:
5583// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5584// Upon concluding that this matches the desired pattern, the function
5585// will return the pair {NewAddRec, SmallPredsVec} where:
5586// NewAddRec = {%Start,+,%Step}
5587// SmallPredsVec = {P1, P2, P3} as follows:
5588// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5589// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5590// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5591// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5592// under the predicates {P1,P2,P3}.
5593// This predicated rewrite will be cached in PredicatedSCEVRewrites:
// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5595//
5596// TODO's:
5597//
5598// 1) Extend the Induction descriptor to also support inductions that involve
5599// casts: When needed (namely, when we are called in the context of the
5600// vectorizer induction analysis), a Set of cast instructions will be
5601// populated by this method, and provided back to isInductionPHI. This is
5602// needed to allow the vectorizer to properly record them to be ignored by
5603// the cost model and to avoid vectorizing them (otherwise these casts,
5604// which are redundant under the runtime overflow checks, will be
5605// vectorized, which can be costly).
5606//
5607// 2) Support additional induction/PHISCEV patterns: We also want to support
5608// inductions where the sext-trunc / zext-trunc operations (partly) occur
5609// after the induction update operation (the induction increment):
5610//
// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
// which corresponds to a phi->add->trunc->sext/zext->phi update chain.
//
// (Trunc iy ((SExt/ZExt ix (%SymbolicPHI) to iy) + InvariantAccum) to ix)
// which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5616//
5617// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5618std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5619ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5621
5622 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5623 // return an AddRec expression under some predicate.
5624
5625 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5626 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5627 assert(L && "Expecting an integer loop header phi");
5628
5629 // The loop may have multiple entrances or multiple exits; we can analyze
5630 // this phi as an addrec if it has a unique entry value and a unique
5631 // backedge value.
5632 Value *BEValueV = nullptr, *StartValueV = nullptr;
5633 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5634 Value *V = PN->getIncomingValue(i);
5635 if (L->contains(PN->getIncomingBlock(i))) {
5636 if (!BEValueV) {
5637 BEValueV = V;
5638 } else if (BEValueV != V) {
5639 BEValueV = nullptr;
5640 break;
5641 }
5642 } else if (!StartValueV) {
5643 StartValueV = V;
5644 } else if (StartValueV != V) {
5645 StartValueV = nullptr;
5646 break;
5647 }
5648 }
5649 if (!BEValueV || !StartValueV)
5650 return std::nullopt;
5651
5652 const SCEV *BEValue = getSCEV(BEValueV);
5653
5654 // If the value coming around the backedge is an add with the symbolic
5655 // value we just inserted, possibly with casts that we can ignore under
5656 // an appropriate runtime guard, then we found a simple induction variable!
5657 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5658 if (!Add)
5659 return std::nullopt;
5660
5661 // If there is a single occurrence of the symbolic value, possibly
5662 // casted, replace it with a recurrence.
5663 unsigned FoundIndex = Add->getNumOperands();
5664 Type *TruncTy = nullptr;
5665 bool Signed;
5666 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5667 if ((TruncTy =
5668 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5669 if (FoundIndex == e) {
5670 FoundIndex = i;
5671 break;
5672 }
5673
5674 if (FoundIndex == Add->getNumOperands())
5675 return std::nullopt;
5676
5677 // Create an add with everything but the specified operand.
5679 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5680 if (i != FoundIndex)
5681 Ops.push_back(Add->getOperand(i));
5682 const SCEV *Accum = getAddExpr(Ops);
5683
5684 // The runtime checks will not be valid if the step amount is
5685 // varying inside the loop.
5686 if (!isLoopInvariant(Accum, L))
5687 return std::nullopt;
5688
5689 // *** Part2: Create the predicates
5690
5691 // Analysis was successful: we have a phi-with-cast pattern for which we
5692 // can return an AddRec expression under the following predicates:
5693 //
5694 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5695 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5696 // P2: An Equal predicate that guarantees that
5697 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5698 // P3: An Equal predicate that guarantees that
5699 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5700 //
5701 // As we next prove, the above predicates guarantee that:
5702 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5703 //
5704 //
5705 // More formally, we want to prove that:
5706 // Expr(i+1) = Start + (i+1) * Accum
5707 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5708 //
5709 // Given that:
5710 // 1) Expr(0) = Start
5711 // 2) Expr(1) = Start + Accum
5712 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5713 // 3) Induction hypothesis (step i):
5714 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5715 //
5716 // Proof:
5717 // Expr(i+1) =
5718 // = Start + (i+1)*Accum
5719 // = (Start + i*Accum) + Accum
5720 // = Expr(i) + Accum
5721 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5722 // :: from step i
5723 //
5724 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5725 //
5726 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5727 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5728 // + Accum :: from P3
5729 //
5730 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5731 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5732 //
5733 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5734 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5735 //
5736 // By induction, the same applies to all iterations 1<=i<n:
5737 //
5738
5739 // Create a truncated addrec for which we will add a no overflow check (P1).
5740 const SCEV *StartVal = getSCEV(StartValueV);
5741 const SCEV *PHISCEV =
5742 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5743 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5744
5745 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5746 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5747 // will be constant.
5748 //
5749 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5750 // add P1.
5751 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5755 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5756 Predicates.push_back(AddRecPred);
5757 }
5758
5759 // Create the Equal Predicates P2,P3:
5760
5761 // It is possible that the predicates P2 and/or P3 are computable at
5762 // compile time due to StartVal and/or Accum being constants.
5763 // If either one is, then we can check that now and escape if either P2
5764 // or P3 is false.
5765
5766 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5767 // for each of StartVal and Accum
5768 auto getExtendedExpr = [&](const SCEV *Expr,
5769 bool CreateSignExtend) -> const SCEV * {
5770 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5771 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5772 const SCEV *ExtendedExpr =
5773 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5774 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5775 return ExtendedExpr;
5776 };
5777
5778 // Given:
5779 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5780 // = getExtendedExpr(Expr)
5781 // Determine whether the predicate P: Expr == ExtendedExpr
5782 // is known to be false at compile time
5783 auto PredIsKnownFalse = [&](const SCEV *Expr,
5784 const SCEV *ExtendedExpr) -> bool {
5785 return Expr != ExtendedExpr &&
5786 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5787 };
5788
5789 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5790 if (PredIsKnownFalse(StartVal, StartExtended)) {
5791 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5792 return std::nullopt;
5793 }
5794
5795 // The Step is always Signed (because the overflow checks are either
5796 // NSSW or NUSW)
5797 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5798 if (PredIsKnownFalse(Accum, AccumExtended)) {
5799 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5800 return std::nullopt;
5801 }
5802
5803 auto AppendPredicate = [&](const SCEV *Expr,
5804 const SCEV *ExtendedExpr) -> void {
5805 if (Expr != ExtendedExpr &&
5806 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5807 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5808 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5809 Predicates.push_back(Pred);
5810 }
5811 };
5812
5813 AppendPredicate(StartVal, StartExtended);
5814 AppendPredicate(Accum, AccumExtended);
5815
5816 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5817 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5818 // into NewAR if it will also add the runtime overflow checks specified in
5819 // Predicates.
5820 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5821
5822 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5823 std::make_pair(NewAR, Predicates);
5824 // Remember the result of the analysis for this SCEV at this locayyytion.
5825 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5826 return PredRewrite;
5827}
5828
5829std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5831 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5832 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5833 if (!L)
5834 return std::nullopt;
5835
5836 // Check to see if we already analyzed this PHI.
5837 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5838 if (I != PredicatedSCEVRewrites.end()) {
5839 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5840 I->second;
5841 // Analysis was done before and failed to create an AddRec:
5842 if (Rewrite.first == SymbolicPHI)
5843 return std::nullopt;
5844 // Analysis was done before and succeeded to create an AddRec under
5845 // a predicate:
5846 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5847 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5848 return Rewrite;
5849 }
5850
5851 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5852 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5853
5854 // Record in the cache that the analysis failed
5855 if (!Rewrite) {
5857 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5858 return std::nullopt;
5859 }
5860
5861 return Rewrite;
5862}
5863
5864// FIXME: This utility is currently required because the Rewriter currently
5865// does not rewrite this expression:
5866// {0, +, (sext ix (trunc iy to ix) to iy)}
5867// into {0, +, %step},
5868// even when the following Equal predicate exists:
5869// "%step == (sext ix (trunc iy to ix) to iy)".
5871 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2,
5872 ArrayRef<const SCEVPredicate *> NoWrapPreds) const {
5873 if (AR1 == AR2)
5874 return true;
5875
5876 SCEVUnionPredicate NoWrapUnionPred(NoWrapPreds, SE);
5877 SCEVUnionPredicate AllPreds = Preds->getUnionWith(&NoWrapUnionPred, SE);
5878 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5879 if (Expr1 != Expr2 &&
5880 !AllPreds.implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5881 !AllPreds.implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5882 return false;
5883 return true;
5884 };
5885
5886 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5887 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5888 return false;
5889 return true;
5890}
5891
5892/// A helper function for createAddRecFromPHI to handle simple cases.
5893///
5894/// This function tries to find an AddRec expression for the simplest (yet most
5895/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5896/// If it fails, createAddRecFromPHI will use a more general, but slow,
5897/// technique for finding the AddRec expression.
5898const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5899 Value *BEValueV,
5900 Value *StartValueV) {
5901 const Loop *L = LI.getLoopFor(PN->getParent());
5902 assert(L && L->getHeader() == PN->getParent());
5903 assert(BEValueV && StartValueV);
5904
5905 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5906 if (!BO)
5907 return nullptr;
5908
5909 if (BO->Opcode != Instruction::Add)
5910 return nullptr;
5911
5912 const SCEV *Accum = nullptr;
5913 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5914 Accum = getSCEV(BO->RHS);
5915 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5916 Accum = getSCEV(BO->LHS);
5917
5918 if (!Accum)
5919 return nullptr;
5920
5922 if (BO->IsNUW)
5923 Flags = setFlags(Flags, SCEV::FlagNUW);
5924 if (BO->IsNSW)
5925 Flags = setFlags(Flags, SCEV::FlagNSW);
5926
5927 const SCEV *StartVal = getSCEV(StartValueV);
5928 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5929 insertValueToMap(PN, PHISCEV);
5930
5931 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV))
5932 inferNoWrapViaConstantRanges(AR);
5933
5934 // We can add Flags to the post-inc expression only if we
5935 // know that it is *undefined behavior* for BEValueV to
5936 // overflow.
5937 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5938 assert(isLoopInvariant(Accum, L) &&
5939 "Accum is defined outside L, but is not invariant?");
5940 if (isAddRecNeverPoison(BEInst, L))
5941 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5942 }
5943
5944 return PHISCEV;
5945}
5946
5947const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5948 const Loop *L = LI.getLoopFor(PN->getParent());
5949 if (!L || L->getHeader() != PN->getParent())
5950 return nullptr;
5951
5952 // The loop may have multiple entrances or multiple exits; we can analyze
5953 // this phi as an addrec if it has a unique entry value and a unique
5954 // backedge value.
5955 Value *BEValueV = nullptr, *StartValueV = nullptr;
5956 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5957 Value *V = PN->getIncomingValue(i);
5958 if (L->contains(PN->getIncomingBlock(i))) {
5959 if (!BEValueV) {
5960 BEValueV = V;
5961 } else if (BEValueV != V) {
5962 BEValueV = nullptr;
5963 break;
5964 }
5965 } else if (!StartValueV) {
5966 StartValueV = V;
5967 } else if (StartValueV != V) {
5968 StartValueV = nullptr;
5969 break;
5970 }
5971 }
5972 if (!BEValueV || !StartValueV)
5973 return nullptr;
5974
5975 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5976 "PHI node already processed?");
5977
5978 // First, try to find AddRec expression without creating a fictituos symbolic
5979 // value for PN.
5980 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5981 return S;
5982
5983 // Handle PHI node value symbolically.
5984 const SCEV *SymbolicName = getUnknown(PN);
5985 insertValueToMap(PN, SymbolicName);
5986
5987 // Using this symbolic name for the PHI, analyze the value coming around
5988 // the back-edge.
5989 const SCEV *BEValue = getSCEV(BEValueV);
5990
5991 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5992 // has a special value for the first iteration of the loop.
5993
5994 // If the value coming around the backedge is an add with the symbolic
5995 // value we just inserted, then we found a simple induction variable!
5996 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5997 // If there is a single occurrence of the symbolic value, replace it
5998 // with a recurrence.
5999 unsigned FoundIndex = Add->getNumOperands();
6000 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6001 if (Add->getOperand(i) == SymbolicName)
6002 if (FoundIndex == e) {
6003 FoundIndex = i;
6004 break;
6005 }
6006
6007 if (FoundIndex != Add->getNumOperands()) {
6008 // Create an add with everything but the specified operand.
6010 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6011 if (i != FoundIndex)
6012 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6013 L, *this));
6014 const SCEV *Accum = getAddExpr(Ops);
6015
6016 // This is not a valid addrec if the step amount is varying each
6017 // loop iteration, but is not itself an addrec in this loop.
6018 if (isLoopInvariant(Accum, L) ||
6019 (isa<SCEVAddRecExpr>(Accum) &&
6020 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
6022
6023 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6024 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6025 if (BO->IsNUW)
6026 Flags = setFlags(Flags, SCEV::FlagNUW);
6027 if (BO->IsNSW)
6028 Flags = setFlags(Flags, SCEV::FlagNSW);
6029 }
6030 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6031 if (GEP->getOperand(0) == PN) {
6032 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6033 // If the increment has any nowrap flags, then we know the address
6034 // space cannot be wrapped around.
6035 if (NW != GEPNoWrapFlags::none())
6036 Flags = setFlags(Flags, SCEV::FlagNW);
6037 // If the GEP is nuw or nusw with non-negative offset, we know that
6038 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6039 // offset is treated as signed, while the base is unsigned.
6040 if (NW.hasNoUnsignedWrap() ||
6042 Flags = setFlags(Flags, SCEV::FlagNUW);
6043 }
6044
6045 // We cannot transfer nuw and nsw flags from subtraction
6046 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6047 // for instance.
6048 }
6049
6050 const SCEV *StartVal = getSCEV(StartValueV);
6051 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6052
6053 // Okay, for the entire analysis of this edge we assumed the PHI
6054 // to be symbolic. We now need to go back and purge all of the
6055 // entries for the scalars that use the symbolic expression.
6056 forgetMemoizedResults({SymbolicName});
6057 insertValueToMap(PN, PHISCEV);
6058
6059 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV))
6060 inferNoWrapViaConstantRanges(AR);
6061
6062 // We can add Flags to the post-inc expression only if we
6063 // know that it is *undefined behavior* for BEValueV to
6064 // overflow.
6065 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6066 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6067 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6068
6069 return PHISCEV;
6070 }
6071 }
6072 } else {
6073 // Otherwise, this could be a loop like this:
6074 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6075 // In this case, j = {1,+,1} and BEValue is j.
6076 // Because the other in-value of i (0) fits the evolution of BEValue
6077 // i really is an addrec evolution.
6078 //
6079 // We can generalize this saying that i is the shifted value of BEValue
6080 // by one iteration:
6081 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6082
6083 // Do not allow refinement in rewriting of BEValue.
6084 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6085 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6086 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6087 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6088 const SCEV *StartVal = getSCEV(StartValueV);
6089 if (Start == StartVal) {
6090 // Okay, for the entire analysis of this edge we assumed the PHI
6091 // to be symbolic. We now need to go back and purge all of the
6092 // entries for the scalars that use the symbolic expression.
6093 forgetMemoizedResults({SymbolicName});
6094 insertValueToMap(PN, Shifted);
6095 return Shifted;
6096 }
6097 }
6098 }
6099
6100 // Remove the temporary PHI node SCEV that has been inserted while intending
6101 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6102 // as it will prevent later (possibly simpler) SCEV expressions to be added
6103 // to the ValueExprMap.
6104 eraseValueFromMap(PN);
6105
6106 return nullptr;
6107}
6108
6109// Try to match a control flow sequence that branches out at BI and merges back
6110// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6111// match.
6113 Value *&C, Value *&LHS, Value *&RHS) {
6114 C = BI->getCondition();
6115
6116 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6117 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6118
6119 Use &LeftUse = Merge->getOperandUse(0);
6120 Use &RightUse = Merge->getOperandUse(1);
6121
6122 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6123 LHS = LeftUse;
6124 RHS = RightUse;
6125 return true;
6126 }
6127
6128 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6129 LHS = RightUse;
6130 RHS = LeftUse;
6131 return true;
6132 }
6133
6134 return false;
6135}
6136
6138 Value *&Cond, Value *&LHS,
6139 Value *&RHS) {
6140 auto IsReachable =
6141 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6142 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6143 // Try to match
6144 //
6145 // br %cond, label %left, label %right
6146 // left:
6147 // br label %merge
6148 // right:
6149 // br label %merge
6150 // merge:
6151 // V = phi [ %x, %left ], [ %y, %right ]
6152 //
6153 // as "select %cond, %x, %y"
6154
6155 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6156 assert(IDom && "At least the entry block should dominate PN");
6157
6158 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6159 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6160 }
6161 return false;
6162}
6163
6164const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6165 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6166 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
6169 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6170
6171 return nullptr;
6172}
6173
6175 BinaryOperator *CommonInst = nullptr;
6176 // Check if instructions are identical.
6177 for (Value *Incoming : PN->incoming_values()) {
6178 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6179 if (!IncomingInst)
6180 return nullptr;
6181 if (CommonInst) {
6182 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6183 return nullptr; // Not identical, give up
6184 } else {
6185 // Remember binary operator
6186 CommonInst = IncomingInst;
6187 }
6188 }
6189 return CommonInst;
6190}
6191
6192/// Returns SCEV for the first operand of a phi if all phi operands have
6193/// identical opcodes and operands
6194/// eg.
6195/// a: %add = %a + %b
6196/// br %c
6197/// b: %add1 = %a + %b
6198/// br %c
6199/// c: %phi = phi [%add, a], [%add1, b]
6200/// scev(%phi) => scev(%add)
6201const SCEV *
6202ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6203 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6204 if (!CommonInst)
6205 return nullptr;
6206
6207 // Check if SCEV exprs for instructions are identical.
6208 const SCEV *CommonSCEV = getSCEV(CommonInst);
6209 bool SCEVExprsIdentical =
6211 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6212 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6213}
6214
6215const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6216 if (const SCEV *S = createAddRecFromPHI(PN))
6217 return S;
6218
6219 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6220 // phi node for X.
6221 if (Value *V = simplifyInstruction(
6222 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6223 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6224 return getSCEV(V);
6225
6226 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6227 return S;
6228
6229 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6230 return S;
6231
6232 // If it's not a loop phi, we can't handle it yet.
6233 return getUnknown(PN);
6234}
6235
6236bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6237 SCEVTypes RootKind) {
6238 struct FindClosure {
6239 const SCEV *OperandToFind;
6240 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6241 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6242
6243 bool Found = false;
6244
6245 bool canRecurseInto(SCEVTypes Kind) const {
6246 // We can only recurse into the SCEV expression of the same effective type
6247 // as the type of our root SCEV expression, and into zero-extensions.
6248 return RootKind == Kind || NonSequentialRootKind == Kind ||
6249 scZeroExtend == Kind;
6250 };
6251
6252 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6253 : OperandToFind(OperandToFind), RootKind(RootKind),
6254 NonSequentialRootKind(
6256 RootKind)) {}
6257
6258 bool follow(const SCEV *S) {
6259 Found = S == OperandToFind;
6260
6261 return !isDone() && canRecurseInto(S->getSCEVType());
6262 }
6263
6264 bool isDone() const { return Found; }
6265 };
6266
6267 FindClosure FC(OperandToFind, RootKind);
6268 visitAll(Root, FC);
6269 return FC.Found;
6270}
6271
6272std::optional<const SCEV *>
6273ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6274 ICmpInst *Cond,
6275 Value *TrueVal,
6276 Value *FalseVal) {
6277 // Try to match some simple smax or umax patterns.
6278 auto *ICI = Cond;
6279
6280 Value *LHS = ICI->getOperand(0);
6281 Value *RHS = ICI->getOperand(1);
6282
6283 switch (ICI->getPredicate()) {
6284 case ICmpInst::ICMP_SLT:
6285 case ICmpInst::ICMP_SLE:
6286 case ICmpInst::ICMP_ULT:
6287 case ICmpInst::ICMP_ULE:
6288 std::swap(LHS, RHS);
6289 [[fallthrough]];
6290 case ICmpInst::ICMP_SGT:
6291 case ICmpInst::ICMP_SGE:
6292 case ICmpInst::ICMP_UGT:
6293 case ICmpInst::ICMP_UGE:
6294 // a > b ? a+x : b+x -> max(a, b)+x
6295 // a > b ? b+x : a+x -> min(a, b)+x
6297 bool Signed = ICI->isSigned();
6298 const SCEV *LA = getSCEV(TrueVal);
6299 const SCEV *RA = getSCEV(FalseVal);
6300 const SCEV *LS = getSCEV(LHS);
6301 const SCEV *RS = getSCEV(RHS);
6302 if (LA->getType()->isPointerTy()) {
6303 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6304 // Need to make sure we can't produce weird expressions involving
6305 // negated pointers.
6306 if (LA == LS && RA == RS)
6307 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6308 if (LA == RS && RA == LS)
6309 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6310 }
6311 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6312 if (Op->getType()->isPointerTy()) {
6315 return Op;
6316 }
6317 if (Signed)
6318 Op = getNoopOrSignExtend(Op, Ty);
6319 else
6320 Op = getNoopOrZeroExtend(Op, Ty);
6321 return Op;
6322 };
6323 LS = CoerceOperand(LS);
6324 RS = CoerceOperand(RS);
6326 break;
6327 const SCEV *LDiff = getMinusSCEV(LA, LS);
6328 const SCEV *RDiff = getMinusSCEV(RA, RS);
6329 if (LDiff == RDiff)
6330 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6331 LDiff);
6332 LDiff = getMinusSCEV(LA, RS);
6333 RDiff = getMinusSCEV(RA, LS);
6334 if (LDiff == RDiff)
6335 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6336 LDiff);
6337 }
6338 break;
6339 case ICmpInst::ICMP_NE:
6340 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6341 std::swap(TrueVal, FalseVal);
6342 [[fallthrough]];
6343 case ICmpInst::ICMP_EQ:
6344 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6347 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6348 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6349 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6350 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6351 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6352 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6353 return getAddExpr(getUMaxExpr(X, C), Y);
6354 }
6355 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6356 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6357 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6358 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6360 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6361 const SCEV *X = getSCEV(LHS);
6362 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6363 X = ZExt->getOperand();
6364 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6365 const SCEV *FalseValExpr = getSCEV(FalseVal);
6366 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6367 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6368 /*Sequential=*/true);
6369 }
6370 }
6371 break;
6372 default:
6373 break;
6374 }
6375
6376 return std::nullopt;
6377}
6378
6379static std::optional<const SCEV *>
6381 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6382 assert(CondExpr->getType()->isIntegerTy(1) &&
6383 TrueExpr->getType() == FalseExpr->getType() &&
6384 TrueExpr->getType()->isIntegerTy(1) &&
6385 "Unexpected operands of a select.");
6386
6387 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6388 // --> C + (umin_seq cond, x - C)
6389 //
6390 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6391 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6392 // --> C + (umin_seq ~cond, x - C)
6393
6394 // FIXME: while we can't legally model the case where both of the hands
6395 // are fully variable, we only require that the *difference* is constant.
6396 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6397 return std::nullopt;
6398
6399 const SCEV *X, *C;
6400 if (isa<SCEVConstant>(TrueExpr)) {
6401 CondExpr = SE->getNotSCEV(CondExpr);
6402 X = FalseExpr;
6403 C = TrueExpr;
6404 } else {
6405 X = TrueExpr;
6406 C = FalseExpr;
6407 }
6408 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6409 /*Sequential=*/true));
6410}
6411
6412static std::optional<const SCEV *>
6414 Value *FalseVal) {
6415 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6416 return std::nullopt;
6417
6418 const auto *SECond = SE->getSCEV(Cond);
6419 const auto *SETrue = SE->getSCEV(TrueVal);
6420 const auto *SEFalse = SE->getSCEV(FalseVal);
6421 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6422}
6423
6424const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6425 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6426 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6427 assert(TrueVal->getType() == FalseVal->getType() &&
6428 V->getType() == TrueVal->getType() &&
6429 "Types of select hands and of the result must match.");
6430
6431 // For now, only deal with i1-typed `select`s.
6432 if (!V->getType()->isIntegerTy(1))
6433 return getUnknown(V);
6434
6435 if (std::optional<const SCEV *> S =
6436 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6437 return *S;
6438
6439 return getUnknown(V);
6440}
6441
6442const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6443 Value *TrueVal,
6444 Value *FalseVal) {
6445 // Handle "constant" branch or select. This can occur for instance when a
6446 // loop pass transforms an inner loop and moves on to process the outer loop.
6447 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6448 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6449
6450 if (auto *I = dyn_cast<Instruction>(V)) {
6451 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6452 if (std::optional<const SCEV *> S =
6453 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6454 TrueVal, FalseVal))
6455 return *S;
6456 }
6457 }
6458
6459 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6460}
6461
6462/// Expand GEP instructions into add and multiply operations. This allows them
6463/// to be analyzed by regular SCEV code.
6464const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6465 assert(GEP->getSourceElementType()->isSized() &&
6466 "GEP source element type must be sized");
6467
6468 SmallVector<SCEVUse, 4> IndexExprs;
6469 for (Value *Index : GEP->indices())
6470 IndexExprs.push_back(getSCEV(Index));
6471 return getGEPExpr(GEP, IndexExprs);
6472}
6473
6474APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6475 const Instruction *CtxI) {
6476 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6477 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6478 return TrailingZeros >= BitWidth
6480 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6481 };
6482 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6483 // The result is GCD of all operands results.
6484 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6485 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6487 Res, getConstantMultiple(N->getOperand(I), CtxI));
6488 return Res;
6489 };
6490
6491 switch (S->getSCEVType()) {
6492 case scConstant:
6493 return cast<SCEVConstant>(S)->getAPInt();
6494 case scPtrToAddr:
6495 case scPtrToInt:
6496 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6497 case scUDivExpr:
6498 case scVScale:
6499 return APInt(BitWidth, 1);
6500 case scTruncate: {
6501 // Only multiples that are a power of 2 will hold after truncation.
6502 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6503 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6504 return GetShiftedByZeros(TZ);
6505 }
6506 case scZeroExtend: {
6507 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6508 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6509 }
6510 case scSignExtend: {
6511 // Only multiples that are a power of 2 will hold after sext.
6512 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6513 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6514 return GetShiftedByZeros(TZ);
6515 }
6516 case scMulExpr: {
6517 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6518 if (M->hasNoUnsignedWrap()) {
6519 // The result is the product of all operand results.
6520 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6521 for (const SCEV *Operand : M->operands().drop_front())
6522 Res = Res * getConstantMultiple(Operand, CtxI);
6523 return Res;
6524 }
6525
6526 // If there are no wrap guarentees, find the trailing zeros, which is the
6527 // sum of trailing zeros for all its operands.
6528 uint32_t TZ = 0;
6529 for (const SCEV *Operand : M->operands())
6530 TZ += getMinTrailingZeros(Operand, CtxI);
6531 return GetShiftedByZeros(TZ);
6532 }
6533 case scAddExpr:
6534 case scAddRecExpr: {
6535 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6536 if (N->hasNoUnsignedWrap())
6537 return GetGCDMultiple(N);
6538 // Find the trailing bits, which is the minimum of its operands.
6539 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6540 for (const SCEV *Operand : N->operands().drop_front())
6541 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6542 return GetShiftedByZeros(TZ);
6543 }
6544 case scUMaxExpr:
6545 case scSMaxExpr:
6546 case scUMinExpr:
6547 case scSMinExpr:
6549 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6550 case scUnknown: {
6551 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6552 // the point their underlying IR instruction has been defined. If CtxI was
6553 // not provided, use:
6554 // * the first instruction in the entry block if it is an argument
6555 // * the instruction itself otherwise.
6556 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6557 if (!CtxI) {
6558 if (isa<Argument>(U->getValue()))
6559 CtxI = &*F.getEntryBlock().begin();
6560 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6561 CtxI = I;
6562 }
6563 unsigned Known =
6564 computeKnownBits(U->getValue(),
6565 SimplifyQuery(getDataLayout(), &DT, &AC, CtxI)
6566 .allowEphemerals(true))
6567 .countMinTrailingZeros();
6568 return GetShiftedByZeros(Known);
6569 }
6570 case scCouldNotCompute:
6571 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6572 }
6573 llvm_unreachable("Unknown SCEV kind!");
6574}
6575
6577 const Instruction *CtxI) {
6578 // Skip looking up and updating the cache if there is a context instruction,
6579 // as the result will only be valid in the specified context.
6580 if (CtxI)
6581 return getConstantMultipleImpl(S, CtxI);
6582
6583 auto I = ConstantMultipleCache.find(S);
6584 if (I != ConstantMultipleCache.end())
6585 return I->second;
6586
6587 APInt Result = getConstantMultipleImpl(S, CtxI);
6588 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6589 assert(InsertPair.second && "Should insert a new key");
6590 return InsertPair.first->second;
6591}
6592
6594 APInt Multiple = getConstantMultiple(S);
6595 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6596}
6597
6599 const Instruction *CtxI) {
6600 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6601 (unsigned)getTypeSizeInBits(S->getType()));
6602}
6603
6604/// Helper method to assign a range to V from metadata present in the IR.
6605static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6607 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6608 return getConstantRangeFromMetadata(*MD);
6609 if (const auto *CB = dyn_cast<CallBase>(V))
6610 if (std::optional<ConstantRange> Range = CB->getRange())
6611 return Range;
6612 }
6613 if (auto *A = dyn_cast<Argument>(V))
6614 if (std::optional<ConstantRange> Range = A->getRange())
6615 return Range;
6616
6617 return std::nullopt;
6618}
6619
6621 SCEV::NoWrapFlags Flags) {
6622 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6623 AddRec->setNoWrapFlags(Flags);
6624 UnsignedRanges.erase(AddRec);
6625 SignedRanges.erase(AddRec);
6626 ConstantMultipleCache.erase(AddRec);
6627 }
6628}
6629
6630ConstantRange ScalarEvolution::
6631getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6632 const DataLayout &DL = getDataLayout();
6633
6634 unsigned BitWidth = getTypeSizeInBits(U->getType());
6635 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6636
6637 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6638 // use information about the trip count to improve our available range. Note
6639 // that the trip count independent cases are already handled by known bits.
6640 // WARNING: The definition of recurrence used here is subtly different than
6641 // the one used by AddRec (and thus most of this file). Step is allowed to
6642 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6643 // and other addrecs in the same loop (for non-affine addrecs). The code
6644 // below intentionally handles the case where step is not loop invariant.
6645 auto *P = dyn_cast<PHINode>(U->getValue());
6646 if (!P)
6647 return FullSet;
6648
6649 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6650 // even the values that are not available in these blocks may come from them,
6651 // and this leads to false-positive recurrence test.
6652 for (auto *Pred : predecessors(P->getParent()))
6653 if (!DT.isReachableFromEntry(Pred))
6654 return FullSet;
6655
6656 BinaryOperator *BO;
6657 Value *Start, *Step;
6658 if (!matchSimpleRecurrence(P, BO, Start, Step))
6659 return FullSet;
6660
6661 // If we found a recurrence in reachable code, we must be in a loop. Note
6662 // that BO might be in some subloop of L, and that's completely okay.
6663 auto *L = LI.getLoopFor(P->getParent());
6664 assert(L && L->getHeader() == P->getParent());
6665 if (!L->contains(BO->getParent()))
6666 // NOTE: This bailout should be an assert instead. However, asserting
6667 // the condition here exposes a case where LoopFusion is querying SCEV
6668 // with malformed loop information during the midst of the transform.
6669 // There doesn't appear to be an obvious fix, so for the moment bailout
6670 // until the caller issue can be fixed. PR49566 tracks the bug.
6671 return FullSet;
6672
6673 // TODO: Extend to other opcodes such as mul, and div
6674 switch (BO->getOpcode()) {
6675 default:
6676 return FullSet;
6677 case Instruction::AShr:
6678 case Instruction::LShr:
6679 case Instruction::Shl:
6680 break;
6681 };
6682
6683 if (BO->getOperand(0) != P)
6684 // TODO: Handle the power function forms some day.
6685 return FullSet;
6686
6687 unsigned TC = getSmallConstantMaxTripCount(L);
6688 if (!TC || TC >= BitWidth)
6689 return FullSet;
6690
6691 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6692 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6693 assert(KnownStart.getBitWidth() == BitWidth &&
6694 KnownStep.getBitWidth() == BitWidth);
6695
6696 // Compute total shift amount, being careful of overflow and bitwidths.
6697 auto MaxShiftAmt = KnownStep.getMaxValue();
6698 APInt TCAP(BitWidth, TC-1);
6699 bool Overflow = false;
6700 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6701 if (Overflow)
6702 return FullSet;
6703
6704 switch (BO->getOpcode()) {
6705 default:
6706 llvm_unreachable("filtered out above");
6707 case Instruction::AShr: {
6708 // For each ashr, three cases:
6709 // shift = 0 => unchanged value
6710 // saturation => 0 or -1
6711 // other => a value closer to zero (of the same sign)
6712 // Thus, the end value is closer to zero than the start.
6713 auto KnownEnd = KnownBits::ashr(KnownStart,
6714 KnownBits::makeConstant(TotalShift));
6715 if (KnownStart.isNonNegative())
6716 // Analogous to lshr (simply not yet canonicalized)
6717 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6718 KnownStart.getMaxValue() + 1);
6719 if (KnownStart.isNegative())
6720 // End >=u Start && End <=s Start
6721 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6722 KnownEnd.getMaxValue() + 1);
6723 break;
6724 }
6725 case Instruction::LShr: {
6726 // For each lshr, three cases:
6727 // shift = 0 => unchanged value
6728 // saturation => 0
6729 // other => a smaller positive number
6730 // Thus, the low end of the unsigned range is the last value produced.
6731 auto KnownEnd = KnownBits::lshr(KnownStart,
6732 KnownBits::makeConstant(TotalShift));
6733 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6734 KnownStart.getMaxValue() + 1);
6735 }
6736 case Instruction::Shl: {
6737 // Iff no bits are shifted out, value increases on every shift.
6738 auto KnownEnd = KnownBits::shl(KnownStart,
6739 KnownBits::makeConstant(TotalShift));
6740 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6741 return ConstantRange(KnownStart.getMinValue(),
6742 KnownEnd.getMaxValue() + 1);
6743 break;
6744 }
6745 };
6746 return FullSet;
6747}
6748
6749// The goal of this function is to check if recursively visiting the operands
6750// of this PHI might lead to an infinite loop. If we do see such a loop,
6751// there's no good way to break it, so we avoid analyzing such cases.
6752//
6753// getRangeRef previously used a visited set to avoid infinite loops, but this
6754// caused other issues: the result was dependent on the order of getRangeRef
6755// calls, and the interaction with createSCEVIter could cause a stack overflow
6756// in some cases (see issue #148253).
6757//
6758// FIXME: The way this is implemented is overly conservative; this checks
6759// for a few obviously safe patterns, but anything that doesn't lead to
6760// recursion is fine.
6762 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6764 return true;
6765
6766 if (all_of(PHI->operands(),
6767 [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
6768 return true;
6769
6770 return false;
6771}
6772
6773const ConstantRange &
6774ScalarEvolution::getRangeRefIter(const SCEV *S,
6775 ScalarEvolution::RangeSignHint SignHint) {
6776 DenseMap<const SCEV *, ConstantRange> &Cache =
6777 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6778 : SignedRanges;
6779 SmallVector<SCEVUse> WorkList;
6780 SmallPtrSet<const SCEV *, 8> Seen;
6781
6782 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6783 // SCEVUnknown PHI node.
6784 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6785 if (!Seen.insert(Expr).second)
6786 return;
6787 if (Cache.contains(Expr))
6788 return;
6789 switch (Expr->getSCEVType()) {
6790 case scUnknown:
6792 break;
6793 [[fallthrough]];
6794 case scConstant:
6795 case scVScale:
6796 case scTruncate:
6797 case scZeroExtend:
6798 case scSignExtend:
6799 case scPtrToAddr:
6800 case scPtrToInt:
6801 case scAddExpr:
6802 case scMulExpr:
6803 case scUDivExpr:
6804 case scAddRecExpr:
6805 case scUMaxExpr:
6806 case scSMaxExpr:
6807 case scUMinExpr:
6808 case scSMinExpr:
6810 WorkList.push_back(Expr);
6811 break;
6812 case scCouldNotCompute:
6813 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6814 }
6815 };
6816 AddToWorklist(S);
6817
6818 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6819 for (unsigned I = 0; I != WorkList.size(); ++I) {
6820 const SCEV *P = WorkList[I];
6821 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6822 // If it is not a `SCEVUnknown`, just recurse into operands.
6823 if (!UnknownS) {
6824 for (const SCEV *Op : P->operands())
6825 AddToWorklist(Op);
6826 continue;
6827 }
6828 // `SCEVUnknown`'s require special treatment.
6829 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6830 if (!RangeRefPHIAllowedOperands(DT, P))
6831 continue;
6832 for (auto &Op : reverse(P->operands()))
6833 AddToWorklist(getSCEV(Op));
6834 }
6835 }
6836
6837 if (!WorkList.empty()) {
6838 // Use getRangeRef to compute ranges for items in the worklist in reverse
6839 // order. This will force ranges for earlier operands to be computed before
6840 // their users in most cases.
6841 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6842 getRangeRef(P, SignHint);
6843 }
6844 }
6845
6846 return getRangeRef(S, SignHint, 0);
6847}
6848
6849/// Determine the range for a particular SCEV. If SignHint is
6850/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6851/// with a "cleaner" unsigned (resp. signed) representation.
6852const ConstantRange &ScalarEvolution::getRangeRef(
6853 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6854 DenseMap<const SCEV *, ConstantRange> &Cache =
6855 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6856 : SignedRanges;
6858 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6860
6861 // See if we've computed this range already.
6862 auto I = Cache.find(S);
6863 if (I != Cache.end())
6864 return I->second;
6865
6866 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6867 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6868
6869 // Switch to iteratively computing the range for S, if it is part of a deeply
6870 // nested expression.
6872 return getRangeRefIter(S, SignHint);
6873
6874 unsigned BitWidth = getTypeSizeInBits(S->getType());
6875 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6876 using OBO = OverflowingBinaryOperator;
6877
6878 // If the value has known zeros, the maximum value will have those known zeros
6879 // as well.
6880 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6881 APInt Multiple = getNonZeroConstantMultiple(S);
6882 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6883 if (!Remainder.isZero())
6884 ConservativeResult =
6885 ConstantRange(APInt::getMinValue(BitWidth),
6886 APInt::getMaxValue(BitWidth) - Remainder + 1);
6887 }
6888 else {
6889 uint32_t TZ = getMinTrailingZeros(S);
6890 if (TZ != 0) {
6891 ConservativeResult = ConstantRange(
6893 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6894 }
6895 }
6896
6897 switch (S->getSCEVType()) {
6898 case scConstant:
6899 llvm_unreachable("Already handled above.");
6900 case scVScale:
6901 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6902 case scTruncate: {
6903 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6904 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6905 return setRange(
6906 Trunc, SignHint,
6907 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6908 }
6909 case scZeroExtend: {
6910 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6911 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6912 return setRange(
6913 ZExt, SignHint,
6914 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6915 }
6916 case scSignExtend: {
6917 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6918 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6919 return setRange(
6920 SExt, SignHint,
6921 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6922 }
6923 case scPtrToAddr:
6924 case scPtrToInt: {
6925 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6926 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6927 return setRange(Cast, SignHint, X);
6928 }
6929 case scAddExpr: {
6930 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6931 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6932 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6933 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6934 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6935 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6936 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6937 ConservativeResult =
6938 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6939 }
6940 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6941 unsigned WrapType = OBO::AnyWrap;
6942 if (Add->hasNoSignedWrap())
6943 WrapType |= OBO::NoSignedWrap;
6944 if (Add->hasNoUnsignedWrap())
6945 WrapType |= OBO::NoUnsignedWrap;
6946 for (const SCEV *Op : drop_begin(Add->operands()))
6947 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6948 RangeType);
6949 return setRange(Add, SignHint,
6950 ConservativeResult.intersectWith(X, RangeType));
6951 }
6952 case scMulExpr: {
6953 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6954 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6955 for (const SCEV *Op : drop_begin(Mul->operands()))
6956 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6957 return setRange(Mul, SignHint,
6958 ConservativeResult.intersectWith(X, RangeType));
6959 }
6960 case scUDivExpr: {
6961 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6962 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6963 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6964 return setRange(UDiv, SignHint,
6965 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6966 }
6967 case scAddRecExpr: {
6968 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6969 // If there's no unsigned wrap, the value will never be less than its
6970 // initial value.
6971 if (AddRec->hasNoUnsignedWrap()) {
6972 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6973 if (!UnsignedMinValue.isZero())
6974 ConservativeResult = ConservativeResult.intersectWith(
6975 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6976 }
6977
6978 // If there's no signed wrap, and all the operands except initial value have
6979 // the same sign or zero, the value won't ever be:
6980 // 1: smaller than initial value if operands are non negative,
6981 // 2: bigger than initial value if operands are non positive.
6982 // For both cases, value can not cross signed min/max boundary.
6983 if (AddRec->hasNoSignedWrap()) {
6984 bool AllNonNeg = true;
6985 bool AllNonPos = true;
6986 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6987 if (!isKnownNonNegative(AddRec->getOperand(i)))
6988 AllNonNeg = false;
6989 if (!isKnownNonPositive(AddRec->getOperand(i)))
6990 AllNonPos = false;
6991 }
6992 if (AllNonNeg)
6993 ConservativeResult = ConservativeResult.intersectWith(
6996 RangeType);
6997 else if (AllNonPos)
6998 ConservativeResult = ConservativeResult.intersectWith(
7000 getSignedRangeMax(AddRec->getStart()) +
7001 1),
7002 RangeType);
7003 }
7004
7005 // TODO: non-affine addrec
7006 if (AddRec->isAffine()) {
7007 const SCEV *MaxBEScev =
7009 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7010 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7011
7012 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7013 // MaxBECount's active bits are all <= AddRec's bit width.
7014 if (MaxBECount.getBitWidth() > BitWidth &&
7015 MaxBECount.getActiveBits() <= BitWidth)
7016 MaxBECount = MaxBECount.trunc(BitWidth);
7017 else if (MaxBECount.getBitWidth() < BitWidth)
7018 MaxBECount = MaxBECount.zext(BitWidth);
7019
7020 if (MaxBECount.getBitWidth() == BitWidth) {
7021 auto [RangeFromAffine, Flags] = getRangeForAffineAR(
7022 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7023 ConservativeResult =
7024 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7025 const_cast<SCEVAddRecExpr *>(AddRec)->setNoWrapFlags(Flags);
7026
7027 auto RangeFromFactoring = getRangeViaFactoring(
7028 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7029 ConservativeResult =
7030 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7031 }
7032 }
7033
7034 // Now try symbolic BE count and more powerful methods.
7036 const SCEV *SymbolicMaxBECount =
7038 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7039 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7040 AddRec->hasNoSelfWrap()) {
7041 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7042 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7043 ConservativeResult =
7044 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7045 }
7046 }
7047 }
7048
7049 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7050 }
7051 case scUMaxExpr:
7052 case scSMaxExpr:
7053 case scUMinExpr:
7054 case scSMinExpr:
7055 case scSequentialUMinExpr: {
7057 switch (S->getSCEVType()) {
7058 case scUMaxExpr:
7059 ID = Intrinsic::umax;
7060 break;
7061 case scSMaxExpr:
7062 ID = Intrinsic::smax;
7063 break;
7064 case scUMinExpr:
7066 ID = Intrinsic::umin;
7067 break;
7068 case scSMinExpr:
7069 ID = Intrinsic::smin;
7070 break;
7071 default:
7072 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7073 }
7074
7075 const auto *NAry = cast<SCEVNAryExpr>(S);
7076 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7077 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7078 X = X.intrinsic(
7079 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7080 return setRange(S, SignHint,
7081 ConservativeResult.intersectWith(X, RangeType));
7082 }
7083 case scUnknown: {
7084 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7085 Value *V = U->getValue();
7086
7087 // Check if the IR explicitly contains !range metadata.
7088 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7089 if (MDRange)
7090 ConservativeResult =
7091 ConservativeResult.intersectWith(*MDRange, RangeType);
7092
7093 // Use facts about recurrences in the underlying IR. Note that add
7094 // recurrences are AddRecExprs and thus don't hit this path. This
7095 // primarily handles shift recurrences.
7096 auto CR = getRangeForUnknownRecurrence(U);
7097 ConservativeResult = ConservativeResult.intersectWith(CR);
7098
7099 // See if ValueTracking can give us a useful range.
7100 const DataLayout &DL = getDataLayout();
7101 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7102 if (Known.getBitWidth() != BitWidth)
7103 Known = Known.zextOrTrunc(BitWidth);
7104
7105 // ValueTracking may be able to compute a tighter result for the number of
7106 // sign bits than for the value of those sign bits.
7107 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7108 if (U->getType()->isPointerTy()) {
7109 // If the pointer size is larger than the index size type, this can cause
7110 // NS to be larger than BitWidth. So compensate for this.
7111 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7112 int ptrIdxDiff = ptrSize - BitWidth;
7113 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7114 NS -= ptrIdxDiff;
7115 }
7116
7117 if (NS > 1) {
7118 // If we know any of the sign bits, we know all of the sign bits.
7119 if (!Known.Zero.getHiBits(NS).isZero())
7120 Known.Zero.setHighBits(NS);
7121 if (!Known.One.getHiBits(NS).isZero())
7122 Known.One.setHighBits(NS);
7123 }
7124
7125 if (Known.getMinValue() != Known.getMaxValue() + 1)
7126 ConservativeResult = ConservativeResult.intersectWith(
7127 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7128 RangeType);
7129 if (NS > 1)
7130 ConservativeResult = ConservativeResult.intersectWith(
7131 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7132 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7133 RangeType);
7134
7135 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7136 // Strengthen the range if the underlying IR value is a
7137 // global/alloca/heap allocation using the size of the object.
7138 bool CanBeNull;
7139 uint64_t DerefBytes = V->getPointerDereferenceableBytes(
7140 DL, CanBeNull, /*CanBeFreed=*/nullptr);
7141 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7142 // The highest address the object can start is DerefBytes bytes before
7143 // the end (unsigned max value). If this value is not a multiple of the
7144 // alignment, the last possible start value is the next lowest multiple
7145 // of the alignment. Note: The computations below cannot overflow,
7146 // because if they would there's no possible start address for the
7147 // object.
7148 APInt MaxVal =
7149 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7150 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7151 uint64_t Rem = MaxVal.urem(Align);
7152 MaxVal -= APInt(BitWidth, Rem);
7153 APInt MinVal = APInt::getZero(BitWidth);
7154 if (llvm::isKnownNonZero(V, DL))
7155 MinVal = Align;
7156 ConservativeResult = ConservativeResult.intersectWith(
7157 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7158 }
7159 }
7160
7161 // A range of Phi is a subset of union of all ranges of its input.
7162 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7163 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7164 // AddRecs; return the range for the corresponding AddRec.
7165 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7166 return getRangeRef(AR, SignHint, Depth + 1);
7167
7168 // Make sure that we do not run over cycled Phis.
7169 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7170 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7171
7172 for (const auto &Op : Phi->operands()) {
7173 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7174 RangeFromOps = RangeFromOps.unionWith(OpRange);
7175 // No point to continue if we already have a full set.
7176 if (RangeFromOps.isFullSet())
7177 break;
7178 }
7179 ConservativeResult =
7180 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7181 }
7182 }
7183
7184 // vscale can't be equal to zero
7185 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7186 if (II->getIntrinsicID() == Intrinsic::vscale) {
7187 ConstantRange Disallowed = APInt::getZero(BitWidth);
7188 ConservativeResult = ConservativeResult.difference(Disallowed);
7189 }
7190
7191 return setRange(U, SignHint, std::move(ConservativeResult));
7192 }
7193 case scCouldNotCompute:
7194 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7195 }
7196
7197 return setRange(S, SignHint, std::move(ConservativeResult));
7198}
7199
7200// Given a StartRange, Step and MaxBECount for an expression compute a range of
7201// values that the expression can take. Initially, the expression has a value
7202// from StartRange and then is changed by Step up to MaxBECount times. Signed
7203// argument defines if we treat Step as signed or unsigned. The second return
7204// value indicates that no wrapping occurred.
// The full range (paired with false) is returned when StartRange is already
// full, when Step * MaxBECount cannot fit in the bit width, or when the moved
// boundary lands back inside StartRange.
// NOTE(review): source line 7206 (the function name and leading parameters)
// is missing from this listing. getRangeForAffineAR calls this function as
// getRangeForAffineARHelper(Step, StartRange, MaxBECount, Signed). Step is
// reassigned below, so it is presumably taken by value — confirm.
7205static std::pair<ConstantRange, bool>
7207 const APInt &MaxBECount, bool Signed) {
7208 unsigned BitWidth = Step.getBitWidth();
7209 assert(BitWidth == StartRange.getBitWidth() &&
7210 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7211 // If either Step or MaxBECount is 0, then the expression won't change, and we
7212 // just need to return the initial range.
7213 if (Step == 0 || MaxBECount == 0)
7214 return {StartRange, true};
7215
7216 // If we don't know anything about the initial value (i.e. StartRange is
7217 // FullRange), then we don't know anything about the final range either.
7218 // Return FullRange.
7219 if (StartRange.isFullSet())
7220 return {ConstantRange::getFull(BitWidth), false};
7221
7222 // If Step is signed and negative, then we use its absolute value, but we also
7223 // note that we're moving in the opposite direction.
7224 bool Descending = Signed && Step.isNegative();
7225
7226 if (Signed)
7227 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7228 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7229 // These equations hold true due to the well-defined wrap-around behavior of
7230 // APInt.
7231 Step = Step.abs();
7232
7233 // Check if Offset is more than full span of BitWidth. If it is, the
7234 // expression is guaranteed to overflow.
  // Equivalently: MaxBECount > UINT_MAX / Step, so Step * MaxBECount would not
  // fit in BitWidth bits.
7235 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7236 return {ConstantRange::getFull(BitWidth), false};
7237
7238 // Offset is by how much the expression can change. Checks above guarantee no
7239 // overflow here.
7240 APInt Offset = Step * MaxBECount;
7241
7242 // Minimum value of the final range will match the minimal value of StartRange
7243 // if the expression is increasing and will be decreased by Offset otherwise.
7244 // Maximum value of the final range will match the maximal value of StartRange
7245 // if the expression is decreasing and will be increased by Offset otherwise.
  // StartUpper is the inclusive maximum: ConstantRange's upper bound is
  // exclusive, which is also why NewUpper is incremented again below.
7246 APInt StartLower = StartRange.getLower();
7247 APInt StartUpper = StartRange.getUpper() - 1;
7248 bool Overflow;
7249 APInt MovedBoundary;
7250 if (Signed) {
7251 // This does not use sadd_ov, as we want to check overflow for a signed
7252 // start with an unsigned offset.
7253 if (Descending) {
7254 MovedBoundary = StartLower - std::move(Offset);
7255 Overflow = MovedBoundary.sgt(StartLower) || StartRange.isSignWrappedSet();
7256 } else {
7257 MovedBoundary = StartUpper + std::move(Offset);
7258 Overflow = MovedBoundary.slt(StartUpper) || StartRange.isSignWrappedSet();
7259 }
7260 } else {
    // Unsigned: Descending is always false here, so only the upper boundary
    // moves. uadd_ov reports a carry-out through Overflow.
7261 MovedBoundary = StartUpper.uadd_ov(std::move(Offset), Overflow);
7262 Overflow |= StartRange.isWrappedSet();
7263 }
7264
7265 // It's possible that the new minimum/maximum value will fall into the initial
7266 // range (due to wrap around). This means that the expression can take any
7267 // value in this bitwidth, and we have to return full range.
7268 if (StartRange.contains(MovedBoundary))
7269 return {ConstantRange::getFull(BitWidth), false};
7270
7271 APInt NewLower =
7272 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7273 APInt NewUpper =
7274 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7275 NewUpper += 1;
7276
7277 // Return [NewLower, NewUpper). The range is returned even if Overflow is set;
  // the bool is true only if no overflow or wrapped start range was detected.
7278 return {ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper)),
7279 !Overflow};
7280}
7281
// Computes a conservative range for an affine recurrence that starts in the
// range of Start and advances by Step at most MaxBECount times. Also returns
// the no-wrap flags the analysis was able to establish.
// The signed analysis runs the helper with the smallest and the largest signed
// step values and unions the two results. The unsigned analysis uses the
// largest unsigned step value. Per the final comment, the two ranges are then
// intersected.
7282std::pair<ConstantRange, SCEV::NoWrapFlags>
7283ScalarEvolution::getRangeForAffineAR(const SCEV *Start, const SCEV *Step,
7284 const APInt &MaxBECount) {
7285 assert(getTypeSizeInBits(Start->getType()) ==
7286 getTypeSizeInBits(Step->getType()) &&
7287 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7288 "mismatched bit widths");
7289
7290 // First, consider step signed.
7291 ConstantRange StartSRange = getSignedRange(Start);
7292 ConstantRange StepSRange = getSignedRange(Step);
7293
7294 // If Step can be both positive and negative, we need to find ranges for the
7295 // maximum absolute step values in both directions and union them.
7296 auto [SR1, NSW1] = getRangeForAffineARHelper(
7297 StepSRange.getSignedMin(), StartSRange, MaxBECount, /*Signed=*/true);
7298 auto [SR2, NSW2] = getRangeForAffineARHelper(StepSRange.getSignedMax(),
7299 StartSRange, MaxBECount,
7300 /*Signed=*/true);
7301 ConstantRange SR = SR1.unionWith(SR2);
7302
7303 // Next, consider step unsigned.
7304 auto [UR, NUW] = getRangeForAffineARHelper(
7305 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7306 /*Signed=*/false);
7307
  // NOTE(review): source lines 7308, 7310, 7312 and 7315 are missing from
  // this listing. From the visible conditions, NUW is presumably reported when
  // the unsigned helper saw no wrap, and NSW only when both signed helper calls
  // saw none. Line 7315 presumably returns the intersected range together with
  // those flags — confirm against the full source.
7309 if (NUW)
7311 if (NSW1 && NSW2)
7313
7314 // Finally, intersect signed and unsigned ranges.
7316}
7317
// Computes a range for an affine AddRec that is asserted not to self-wrap. It
// does this by looking at the AddRec's value at its start and after MaxBECount
// iterations. The full set of width BitWidth is returned whenever the argument
// cannot be completed, namely when:
//  - the step is not constant;
//  - MaxBECount's type is wider than the AddRec's type;
//  - MaxBECount may exceed the number of steps possible without wrapping;
//  - the Start/End range union is wrapped or full;
//  - Start and End cannot be shown to be ordered in the step's direction.
7318ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7319 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7320 ScalarEvolution::RangeSignHint SignHint) {
7321 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7322 assert(AddRec->hasNoSelfWrap() &&
7323 "This only works for non-self-wrapping AddRecs!");
7324 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7325 const SCEV *Step = AddRec->getStepRecurrence(*this);
7326 // Only deal with constant step to save compile time.
7327 if (!isa<SCEVConstant>(Step))
7328 return ConstantRange::getFull(BitWidth);
7329 // Let's make sure that we can prove that we do not self-wrap during
7330 // MaxBECount iterations. We need this because MaxBECount is a maximum
7331 // iteration count estimate, and we might infer nw from some exit for which we
7332 // do not know max exit count (or any other side reasoning).
7333 // TODO: Turn into assert at some point.
7334 if (getTypeSizeInBits(MaxBECount->getType()) >
7335 getTypeSizeInBits(AddRec->getType()))
7336 return ConstantRange::getFull(BitWidth);
7337 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
  // RangeWidth is all-ones (the unsigned maximum). For a constant step,
  // umin(Step, -Step) is its magnitude viewed as an unsigned value, so
  // MaxItersWithoutWrap = UMAX / |Step|.
7338 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7339 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7340 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7341 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7342 MaxItersWithoutWrap))
7343 return ConstantRange::getFull(BitWidth);
7344
  // NOTE(review): source lines 7346 and 7348 (the predicate initializers) are
  // missing from this listing. They presumably select the signed or unsigned
  // <= / >= predicate based on IsSigned — confirm.
7345 ICmpInst::Predicate LEPred =
7347 ICmpInst::Predicate GEPred =
7349 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7350
7351 // We know that there is no self-wrap. Let's take Start and End values and
7352 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7353 // the iteration. They either lie inside the range [Min(Start, End),
7354 // Max(Start, End)] or outside it:
7355 //
7356 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7357 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7358 //
7359 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7360 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7361 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7362 // Start <= End and step is positive, or Start >= End and step is negative.
7363 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7364 ConstantRange StartRange = getRangeRef(Start, SignHint);
7365 ConstantRange EndRange = getRangeRef(End, SignHint);
7366 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7367 // If they already cover full iteration space, we will know nothing useful
7368 // even if we prove what we want to prove.
7369 if (RangeBetween.isFullSet())
7370 return RangeBetween;
7371 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7372 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7373 : RangeBetween.isWrappedSet();
7374 if (IsWrappedSet)
7375 return ConstantRange::getFull(BitWidth);
7376
7377 if (isKnownPositive(Step) &&
7378 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7379 return RangeBetween;
7380 if (isKnownNegative(Step) &&
7381 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7382 return RangeBetween;
7383 return ConstantRange::getFull(BitWidth);
7384}
7385
// Computes a range for {Start,+,Step} when Start and Step are both selects of
// integer constants on the same condition (optionally behind a constant offset
// and/or an integral cast). Each arm is analyzed as its own constant affine
// recurrence with getRangeForAffineAR, and the two ranges are unioned.
// Returns the full set if either operand does not match this pattern or the
// two select conditions differ.
7386ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7387 const SCEV *Step,
7388 const APInt &MaxBECount) {
7389 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7390 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7391
7392 unsigned BitWidth = MaxBECount.getBitWidth();
7393 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7394 getTypeSizeInBits(Step->getType()) == BitWidth &&
7395 "mismatched bit widths");
7396
  // Matches S against [C +] [cast] (select Cond, C1, C2) with integer
  // constants C1 and C2. On success it records Cond, plus C1 and C2 with the
  // peeled cast and offset re-applied. On failure, Condition is left null.
7397 struct SelectPattern {
7398 Value *Condition = nullptr;
7399 APInt TrueValue;
7400 APInt FalseValue;
7401
7402 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7403 const SCEV *S) {
7404 std::optional<unsigned> CastOp;
7405 APInt Offset(BitWidth, 0);
7406
      // NOTE(review): source line 7407 (the start of the assertion ending in
      // "Should be!") is missing from this listing.
7408 "Should be!");
7409
7410 // Peel off a constant offset. In the future we could consider being
7411 // smarter here and handle {Start+Step,+,Step} too.
      // On a match, S is rebound to the non-constant addend.
7412 const APInt *Off;
7413 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7414 Offset = *Off;
7415
7416 // Peel off a cast operation
7417 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7418 CastOp = SCast->getSCEVType();
7419 S = SCast->getOperand();
7420 }
7421
7422 using namespace llvm::PatternMatch;
7423
7424 auto *SU = dyn_cast<SCEVUnknown>(S);
7425 const APInt *TrueVal, *FalseVal;
      // Condition is explicitly reset on failure. Presumably m_Value may
      // already have bound it before a later sub-pattern failed.
7426 if (!SU ||
7427 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7428 m_APInt(FalseVal)))) {
7429 Condition = nullptr;
7430 return;
7431 }
7432
7433 TrueValue = *TrueVal;
7434 FalseValue = *FalseVal;
7435
7436 // Re-apply the cast we peeled off earlier
7437 if (CastOp)
7438 switch (*CastOp) {
7439 default:
7440 llvm_unreachable("Unknown SCEV cast type!");
7441
7442 case scTruncate:
7443 TrueValue = TrueValue.trunc(BitWidth);
7444 FalseValue = FalseValue.trunc(BitWidth);
7445 break;
7446 case scZeroExtend:
7447 TrueValue = TrueValue.zext(BitWidth);
7448 FalseValue = FalseValue.zext(BitWidth);
7449 break;
7450 case scSignExtend:
7451 TrueValue = TrueValue.sext(BitWidth);
7452 FalseValue = FalseValue.sext(BitWidth);
7453 break;
7454 }
7455
7456 // Re-apply the constant offset we peeled off earlier
7457 TrueValue += Offset;
7458 FalseValue += Offset;
7459 }
7460
7461 bool isRecognized() { return Condition != nullptr; }
7462 };
7463
7464 SelectPattern StartPattern(*this, BitWidth, Start);
7465 if (!StartPattern.isRecognized())
7466 return ConstantRange::getFull(BitWidth);
7467
7468 SelectPattern StepPattern(*this, BitWidth, Step);
7469 if (!StepPattern.isRecognized())
7470 return ConstantRange::getFull(BitWidth);
7471
7472 if (StartPattern.Condition != StepPattern.Condition) {
7473 // We don't handle this case today; but we could, by considering four
7474 // possibilities below instead of two. I'm not sure if there are cases where
7475 // that will help over what getRange already does, though.
7476 return ConstantRange::getFull(BitWidth);
7477 }
7478
7479 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7480 // construct arbitrary general SCEV expressions here. This function is called
7481 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7482 // say) can end up caching a suboptimal value.
7483
7484 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7485 // C2352 and C2512 (otherwise it isn't needed).
7486
7487 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7488 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7489 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7490 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7491
7492 ConstantRange TrueRange =
7493 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount).first;
7494 ConstantRange FalseRange =
7495 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount).first;
7496
7497 return TrueRange.unionWith(FalseRange);
7498}
7499
// Returns the SCEV no-wrap flags implied by the IR flags on V. Returns
// FlagAnyWrap if there are none, or if they cannot safely be applied to the
// SCEV.
// ConstantExprs always yield FlagAnyWrap. Otherwise V must be a BinaryOperator
// (enforced by cast<>). Flags are kept only if isSCEVExprNeverPoison(BinOp)
// holds, because the same SCEV may also be reached from other instructions.
// NOTE(review): source lines 7505, 7508, 7511 and 7513 are missing from this
// listing. Presumably 7505 declares Flags initialized to FlagAnyWrap, and the
// others add NUW/NSW (both, for a disjoint instruction) — confirm.
7500SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7501 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7502 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7503
7504 // Return early if there are no flags to propagate to the SCEV.
7506 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(BinOp);
7507 PDI && PDI->isDisjoint()) {
7509 } else {
7510 if (BinOp->hasNoUnsignedWrap())
7512 if (BinOp->hasNoSignedWrap())
7514 }
7515 if (Flags == SCEV::FlagAnyWrap)
7516 return SCEV::FlagAnyWrap;
7517
7518 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7519}
7520
7521const Instruction *
7522ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7523 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7524 return &*AddRec->getLoop()->getHeader()->begin();
7525 if (auto *U = dyn_cast<SCEVUnknown>(S))
7526 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7527 return I;
7528 return nullptr;
7529}
7530
7531const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7532 bool &Precise) {
7533 Precise = true;
7534 // Do a bounded search of the def relation of the requested SCEVs.
7535 SmallPtrSet<const SCEV *, 16> Visited;
7536 SmallVector<SCEVUse> Worklist;
7537 auto pushOp = [&](const SCEV *S) {
7538 if (!Visited.insert(S).second)
7539 return;
7540 // Threshold of 30 here is arbitrary.
7541 if (Visited.size() > 30) {
7542 Precise = false;
7543 return;
7544 }
7545 Worklist.push_back(S);
7546 };
7547
7548 for (SCEVUse S : Ops)
7549 pushOp(S);
7550
7551 const Instruction *Bound = nullptr;
7552 while (!Worklist.empty()) {
7553 SCEVUse S = Worklist.pop_back_val();
7554 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7555 if (!Bound || DT.dominates(Bound, DefI))
7556 Bound = DefI;
7557 } else {
7558 for (SCEVUse Op : S->operands())
7559 pushOp(Op);
7560 }
7561 }
7562 return Bound ? Bound : &*F.getEntryBlock().begin();
7563}
7564
7565const Instruction *
7566ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7567 bool Discard;
7568 return getDefiningScopeBound(Ops, Discard);
7569}
7570
// Returns true if execution of A is guaranteed to be followed by execution of
// B. Only two shapes are recognized:
//  - A and B are in the same block, with execution guaranteed to flow to B;
//  - A is in the preheader of the loop whose header contains B, and execution
//    is guaranteed to flow through the rest of the preheader and the header
//    prefix up to B.
// Anything else conservatively returns false.
// NOTE(review): source lines 7574 and 7581 are missing from this listing.
// They presumably call isGuaranteedToTransferExecutionToSuccessor starting
// from A's iterator — confirm.
7571bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7572 const Instruction *B) {
7573 if (A->getParent() == B->getParent() &&
7575 B->getIterator()))
7576 return true;
7577
7578 auto *BLoop = LI.getLoopFor(B->getParent());
7579 if (BLoop && BLoop->getHeader() == B->getParent() &&
7580 BLoop->getLoopPreheader() == A->getParent() &&
7582 A->getParent()->end()) &&
7583 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7584 B->getIterator()))
7585 return true;
7586 return false;
7587}
7588
  // NOTE(review): the signature line (7589) is missing from this listing.
  // isGuaranteedNotToCauseUB calls this function as
  // isGuaranteedNotToBePoison(const SCEV *).
  // Collects the sub-expressions of Op that may be poison, with
  // LookThroughMaybePoisonBlocking enabled. Op is reported as not poison only
  // if none were found.
7590 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7591 visitAll(Op, PC);
7592 return PC.MaybePoison.empty();
7593}
7594
7595bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7596 return !SCEVExprContains(Op, [this](const SCEV *S) {
7597 const SCEV *Op1;
7598 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7599 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7600 // is a non-zero constant, we have to assume the UDiv may be UB.
7601 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7602 });
7603}
7604
// Returns true if the no-wrap flags carried by I may be applied to the SCEV
// built for I. Two things are required:
//  - I must not yield poison (the leading check below);
//  - I must execute every time the defining scope of its SCEV-able operands is
//    entered (checked via getDefiningScopeBound and
//    isGuaranteedToTransferExecutionTo).
7605bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7606 // Only proceed if we can prove that I does not yield poison.
  // NOTE(review): the condition guarding this return (source line 7607) is
  // missing from this listing. It is presumably a programUndefinedIfPoison(I)
  // test — confirm.
7608 return false;
7609
7610 // At this point we know that if I is executed, then it does not wrap
7611 // according to at least one of NSW or NUW. If I is not executed, then we do
7612 // not know if the calculation that I represents would wrap. Multiple
7613 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7614 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7615 // derived from other instructions that map to the same SCEV. We cannot make
7616 // that guarantee for cases where I is not executed. So we need to find an
7617 // upper bound on the defining scope for the SCEV, and prove that I is
7618 // executed every time we enter that scope. When the bounding scope is a
7619 // loop (the common case), this is equivalent to proving I executes on every
7620 // iteration of that loop.
7621 SmallVector<SCEVUse> SCEVOps;
7622 for (const Use &Op : I->operands()) {
7623 // I could be an extractvalue from a call to an overflow intrinsic.
7624 // TODO: We can do better here in some cases.
7625 if (isSCEVable(Op->getType()))
7626 SCEVOps.push_back(getSCEV(Op));
7627 }
7628 auto *DefI = getDefiningScopeBound(SCEVOps);
7629 return isGuaranteedToTransferExecutionTo(DefI, I);
7630}
7631
// Returns true if the post-increment add recurrence I of loop L cannot be
// poison. Besides the isSCEVExprNeverPoison check, this handles loops with a
// single exiting block and no abnormal exits. The users of I that poison would
// reach are followed, but only inside L. If I being poison would necessarily
// trigger UB in a user whose block dominates the exiting block, I cannot be
// poison.
// NOTE(review): source line 7650 (presumably the declaration of Worklist) is
// missing from this listing.
7632bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7633 // If we know that \c I can never be poison period, then that's enough.
7634 if (isSCEVExprNeverPoison(I))
7635 return true;
7636
7637 // If the loop only has one exit, then we know that, if the loop is entered,
7638 // any instruction dominating that exit will be executed. If any such
7639 // instruction would result in UB, the addrec cannot be poison.
7640 //
7641 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7642 // also handles uses outside the loop header (they just need to dominate the
7643 // single exit).
7644
7645 auto *ExitingBB = L->getExitingBlock();
7646 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7647 return false;
7648
7649 SmallPtrSet<const Value *, 16> KnownPoison;
7651
7652 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7653 // things that are known to be poison under that assumption go on the
7654 // Worklist.
7655 KnownPoison.insert(I);
7656 Worklist.push_back(I);
7657
7658 while (!Worklist.empty()) {
7659 const Instruction *Poison = Worklist.pop_back_val();
7660
7661 for (const Use &U : Poison->uses()) {
7662 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7663 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7664 DT.dominates(PoisonUser->getParent(), ExitingBB))
7665 return true;
7666
      // Only propagate poison through users inside the loop; each user is
      // enqueued at most once.
7667 if (propagatesPoison(U) && L->contains(PoisonUser))
7668 if (KnownPoison.insert(PoisonUser).second)
7669 Worklist.push_back(PoisonUser);
7670 }
7671 }
7672
7673 return false;
7674}
7675
// Returns two properties of L, computing them on first use and caching them in
// LoopPropertiesCache. Both are computed by scanning the instructions of L's
// blocks:
//  - HasNoAbnormalExits: cleared by a per-instruction check whose condition
//    (source line 7702) is missing from this listing. It is presumably a test
//    that the instruction transfers execution to its successor — confirm.
//  - HasNoSideEffects: cleared by any instruction for which HasSideEffects
//    holds.
7676ScalarEvolution::LoopProperties
7677ScalarEvolution::getLoopProperties(const Loop *L) {
7678 using LoopProperties = ScalarEvolution::LoopProperties;
7679
7680 auto Itr = LoopPropertiesCache.find(L);
7681 if (Itr == LoopPropertiesCache.end()) {
    // Stores count only when not simple. Anything that may throw counts.
    // Non-volatile mem intrinsics never count. Everything else counts iff it
    // may write memory.
7682 auto HasSideEffects = [](Instruction *I) {
7683 if (auto *SI = dyn_cast<StoreInst>(I))
7684 return !SI->isSimple();
7685
7686 if (I->mayThrow())
7687 return true;
7688
7689 // Non-volatile memset / memcpy do not count as side-effect for forward
7690 // progress.
7691 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7692 return false;
7693
7694 return I->mayWriteToMemory();
7695 };
7696
7697 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7698 /*HasNoSideEffects*/ true};
7699
7700 for (auto *BB : L->getBlocks())
7701 for (auto &I : *BB) {
      // NOTE(review): line 7702, the condition guarding the next statement, is
      // missing from this listing.
7703 LP.HasNoAbnormalExits = false;
7704 if (HasSideEffects(&I))
7705 LP.HasNoSideEffects = false;
      // NOTE(review): this break only leaves the loop over the current block's
      // instructions; the outer loop over blocks keeps running.
7706 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7707 break; // We're already as pessimistic as we can get.
7708 }
7709
7710 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7711 assert(InsertPair.second && "We just checked!");
7712 Itr = InsertPair.first;
7713 }
7714
7715 return Itr->second;
7716}
7717
  // NOTE(review): the signature line (7718) is missing from this listing. This
  // is presumably ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) —
  // confirm.
  // Returns true if L is known finite, or if it is mustprogress with no side
  // effects (as computed by loopHasNoSideEffects).
7719 // A mustprogress loop without side effects must be finite.
7720 // TODO: The check used here is very conservative. It's only *specific*
7721 // side effects which are well defined in infinite loops.
7722 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7723}
7724
// Builds the SCEV for V using an explicit stack instead of recursion
// (presumably to bound native stack depth on long def-use chains — confirm).
// Each value is first visited with its flag clear. getOperandsToCreate either
// returns a SCEV directly or lists the operands that need SCEVs first. In the
// latter case the value is re-visited with the flag set, after its operands,
// and createSCEV is called. Values that already have a SCEV are skipped.
// Returns the SCEV recorded for V.
// NOTE(review): source lines 7728-7729 (the stack item type and the Stack
// declaration) and 7741 (presumably the declaration of Ops) are missing from
// this listing.
7725const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7726 // Worklist item with a Value and a bool indicating whether all operands have
7727 // been visited already.
7730
7731 Stack.emplace_back(V, false);
7732 while (!Stack.empty()) {
7733 auto E = Stack.back();
7734 Value *CurV = E.getPointer();
7735
7736 if (getExistingSCEV(CurV)) {
7737 Stack.pop_back();
7738 continue;
7739 }
7740
7742 const SCEV *CreatedSCEV = nullptr;
7743 // If all operands have been visited already, create the SCEV.
7744 if (E.getInt()) {
7745 CreatedSCEV = createSCEV(CurV);
7746 } else {
7747 // Otherwise get the operands we need to create SCEV's for before creating
7748 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7749 // just use it.
7750 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7751 }
7752
7753 if (CreatedSCEV) {
7754 insertValueToMap(CurV, CreatedSCEV);
7755 Stack.pop_back();
7756 } else {
      // Leave CurV on the stack, marked as ready. Its operands are pushed above
      // it, so they are processed before CurV is revisited.
7757 Stack.back().setInt(true);
7758 // Queue its operands which need to be constructed.
7759 for (Value *Op : Ops)
7760 Stack.emplace_back(Op, false);
7761 }
7762 }
7763
7764 return getExistingSCEV(V);
7765}
7766
// Helper for createSCEVIter. If the SCEV for V can be formed without first
// building SCEVs for other values, returns it. Otherwise returns nullptr and
// appends to Ops the values whose SCEVs createSCEV(V) will want. Returning
// nullptr with nothing appended means createSCEV(V) is simply called next.
// NOTE(review): several source lines are missing from this listing:
//  - 7786: presumably the declaration of U, used by the opcode switch below;
//  - 7788, 7805: continuations of the MatchBinaryOp calls;
//  - 7912: the select-like-PHI match that fills Cond/LHS/RHS;
//  - 7931: the condition guarding the PHI operand loop;
//  - 7951: the condition inside the EQ/NE select check.
7767const SCEV *
7768ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7769 if (!isSCEVable(V->getType()))
7770 return getUnknown(V);
7771
7772 if (Instruction *I = dyn_cast<Instruction>(V)) {
7773 // Don't attempt to analyze instructions in blocks that aren't
7774 // reachable. Such instructions don't matter, and they aren't required
7775 // to obey basic rules for definitions dominating uses which this
7776 // analysis depends on.
7777 if (!DT.isReachableFromEntry(I->getParent()))
7778 return getUnknown(PoisonValue::get(V->getType()));
7779 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7780 return getConstant(CI);
7781 else if (isa<GlobalAlias>(V))
7782 return getUnknown(V);
7783 else if (!isa<ConstantExpr>(V))
7784 return getUnknown(V);
7785
7787 if (auto BO =
7789 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7790 switch (BO->Opcode) {
7791 case Instruction::Add:
7792 case Instruction::Mul: {
7793 // For additions and multiplications, traverse add/mul chains for which we
7794 // can potentially create a single SCEV, to reduce the number of
7795 // get{Add,Mul}Expr calls.
7796 do {
7797 if (BO->Op) {
7798 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7799 Ops.push_back(BO->Op);
7800 break;
7801 }
7802 }
7803 Ops.push_back(BO->RHS);
7804 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7806 if (!NewBO ||
7807 (BO->Opcode == Instruction::Add &&
7808 (NewBO->Opcode != Instruction::Add &&
7809 NewBO->Opcode != Instruction::Sub)) ||
7810 (BO->Opcode == Instruction::Mul &&
7811 NewBO->Opcode != Instruction::Mul)) {
7812 Ops.push_back(BO->LHS);
7813 break;
7814 }
7815 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7816 // requires a SCEV for the LHS.
7817 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7818 auto *I = dyn_cast<Instruction>(BO->Op);
7819 if (I && programUndefinedIfPoison(I)) {
7820 Ops.push_back(BO->LHS);
7821 break;
7822 }
7823 }
7824 BO = NewBO;
7825 } while (true);
7826 return nullptr;
7827 }
7828 case Instruction::Sub:
7829 case Instruction::UDiv:
7830 case Instruction::URem:
7831 break;
    // For these opcodes, a non-constant RHS leaves Ops empty.
7832 case Instruction::AShr:
7833 case Instruction::Shl:
7834 case Instruction::Xor:
7835 if (!IsConstArg)
7836 return nullptr;
7837 break;
7838 case Instruction::And:
7839 case Instruction::Or:
7840 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7841 return nullptr;
7842 break;
7843 case Instruction::LShr:
7844 return getUnknown(V);
7845 default:
7846 llvm_unreachable("Unhandled binop");
7847 break;
7848 }
7849
7850 Ops.push_back(BO->LHS);
7851 Ops.push_back(BO->RHS);
7852 return nullptr;
7853 }
7854
7855 switch (U->getOpcode()) {
7856 case Instruction::Trunc:
7857 case Instruction::ZExt:
7858 case Instruction::SExt:
7859 case Instruction::PtrToAddr:
7860 case Instruction::PtrToInt:
7861 Ops.push_back(U->getOperand(0));
7862 return nullptr;
7863
7864 case Instruction::BitCast:
7865 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7866 Ops.push_back(U->getOperand(0));
7867 return nullptr;
7868 }
7869 return getUnknown(V);
7870
7871 case Instruction::SDiv:
7872 case Instruction::SRem:
7873 Ops.push_back(U->getOperand(0));
7874 Ops.push_back(U->getOperand(1));
7875 return nullptr;
7876
7877 case Instruction::GetElementPtr:
7878 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7879 "GEP source element type must be sized");
7880 llvm::append_range(Ops, U->operands());
7881 return nullptr;
7882
7883 case Instruction::IntToPtr:
7884 return getUnknown(V);
7885
7886 case Instruction::PHI:
7887 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7888 // relevant nodes for each of them.
7889 //
7890 // The first is just to call simplifyInstruction, and get something back
7891 // that isn't a PHI.
    // NOTE(review): this V (and the loop variable V further down) shadows the
    // parameter V.
7892 if (Value *V = simplifyInstruction(
7893 cast<PHINode>(U),
7894 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7895 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7896 assert(V);
7897 Ops.push_back(V);
7898 return nullptr;
7899 }
7900 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7901 // operands which all perform the same operation, but haven't been
7902 // CSE'ed for whatever reason.
7903 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7904 assert(BO);
7905 Ops.push_back(BO);
7906 return nullptr;
7907 }
7908 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7909 // is equivalent to a select, and analyzes it like a select.
7910 {
7911 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7913 assert(Cond);
7914 assert(LHS);
7915 assert(RHS);
7916 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7917 Ops.push_back(CondICmp->getOperand(0));
7918 Ops.push_back(CondICmp->getOperand(1));
7919 }
7920 Ops.push_back(Cond);
7921 Ops.push_back(LHS);
7922 Ops.push_back(RHS);
7923 return nullptr;
7924 }
7925 }
7926 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7927 // so just construct it recursively.
7928 //
7929 // In addition to getNodeForPHI, also construct nodes which might be needed
7930 // by getRangeRef.
7932 for (Value *V : cast<PHINode>(U)->operands())
7933 Ops.push_back(V);
7934 return nullptr;
7935 }
7936 return nullptr;
7937
7938 case Instruction::Select: {
7939 // Check if U is a select that can be simplified to a SCEVUnknown.
7940 auto CanSimplifyToUnknown = [this, U]() {
7941 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7942 return false;
7943
7944 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7945 if (!ICI)
7946 return false;
7947 Value *LHS = ICI->getOperand(0);
7948 Value *RHS = ICI->getOperand(1);
7949 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7950 ICI->getPredicate() == CmpInst::ICMP_NE) {
7952 return true;
7953 } else if (getTypeSizeInBits(LHS->getType()) >
7954 getTypeSizeInBits(U->getType()))
7955 return true;
7956 return false;
7957 };
7958 if (CanSimplifyToUnknown())
7959 return getUnknown(U);
7960
7961 llvm::append_range(Ops, U->operands());
7962 return nullptr;
    // NOTE(review): unreachable; it follows an unconditional return.
7963 break;
7964 }
7965 case Instruction::Call:
7966 case Instruction::Invoke:
7967 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7968 Ops.push_back(RV);
7969 return nullptr;
7970 }
7971
7972 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7973 switch (II->getIntrinsicID()) {
7974 case Intrinsic::abs:
7975 Ops.push_back(II->getArgOperand(0));
7976 return nullptr;
7977 case Intrinsic::umax:
7978 case Intrinsic::umin:
7979 case Intrinsic::smax:
7980 case Intrinsic::smin:
7981 case Intrinsic::usub_sat:
7982 case Intrinsic::uadd_sat:
7983 Ops.push_back(II->getArgOperand(0));
7984 Ops.push_back(II->getArgOperand(1));
7985 return nullptr;
7986 case Intrinsic::start_loop_iterations:
7987 case Intrinsic::annotation:
7988 case Intrinsic::ptr_annotation:
7989 Ops.push_back(II->getArgOperand(0));
7990 return nullptr;
7991 default:
7992 break;
7993 }
7994 }
7995 break;
7996 }
7997
7998 return nullptr;
7999}
8000
8001const SCEV *ScalarEvolution::createSCEV(Value *V) {
8002 if (!isSCEVable(V->getType()))
8003 return getUnknown(V);
8004
8005 if (Instruction *I = dyn_cast<Instruction>(V)) {
8006 // Don't attempt to analyze instructions in blocks that aren't
8007 // reachable. Such instructions don't matter, and they aren't required
8008 // to obey basic rules for definitions dominating uses which this
8009 // analysis depends on.
8010 if (!DT.isReachableFromEntry(I->getParent()))
8011 return getUnknown(PoisonValue::get(V->getType()));
8012 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
8013 return getConstant(CI);
8014 else if (isa<GlobalAlias>(V))
8015 return getUnknown(V);
8016 else if (!isa<ConstantExpr>(V))
8017 return getUnknown(V);
8018
8019 const SCEV *LHS;
8020 const SCEV *RHS;
8021
8023 if (auto BO =
8025 switch (BO->Opcode) {
8026 case Instruction::Add: {
8027 // The simple thing to do would be to just call getSCEV on both operands
8028 // and call getAddExpr with the result. However if we're looking at a
8029 // bunch of things all added together, this can be quite inefficient,
8030 // because it leads to N-1 getAddExpr calls for N ultimate operands.
8031 // Instead, gather up all the operands and make a single getAddExpr call.
8032 // LLVM IR canonical form means we need only traverse the left operands.
8034 do {
8035 if (BO->Op) {
8036 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8037 AddOps.push_back(OpSCEV);
8038 break;
8039 }
8040
8041 // If a NUW or NSW flag can be applied to the SCEV for this
8042 // addition, then compute the SCEV for this addition by itself
8043 // with a separate call to getAddExpr. We need to do that
8044 // instead of pushing the operands of the addition onto AddOps,
8045 // since the flags are only known to apply to this particular
8046 // addition - they may not apply to other additions that can be
8047 // formed with operands from AddOps.
8048 const SCEV *RHS = getSCEV(BO->RHS);
8049 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8050 if (Flags != SCEV::FlagAnyWrap) {
8051 const SCEV *LHS = getSCEV(BO->LHS);
8052 if (BO->Opcode == Instruction::Sub)
8053 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8054 else
8055 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8056 break;
8057 }
8058 }
8059
8060 if (BO->Opcode == Instruction::Sub)
8061 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8062 else
8063 AddOps.push_back(getSCEV(BO->RHS));
8064
8065 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8067 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8068 NewBO->Opcode != Instruction::Sub)) {
8069 AddOps.push_back(getSCEV(BO->LHS));
8070 break;
8071 }
8072 BO = NewBO;
8073 } while (true);
8074
8075 return getAddExpr(AddOps);
8076 }
8077
8078 case Instruction::Mul: {
8080 do {
8081 if (BO->Op) {
8082 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8083 MulOps.push_back(OpSCEV);
8084 break;
8085 }
8086
8087 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8088 if (Flags != SCEV::FlagAnyWrap) {
8089 LHS = getSCEV(BO->LHS);
8090 RHS = getSCEV(BO->RHS);
8091 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8092 break;
8093 }
8094 }
8095
8096 MulOps.push_back(getSCEV(BO->RHS));
8097 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8099 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8100 MulOps.push_back(getSCEV(BO->LHS));
8101 break;
8102 }
8103 BO = NewBO;
8104 } while (true);
8105
8106 return getMulExpr(MulOps);
8107 }
8108 case Instruction::UDiv:
8109 LHS = getSCEV(BO->LHS);
8110 RHS = getSCEV(BO->RHS);
8111 return getUDivExpr(LHS, RHS);
8112 case Instruction::URem:
8113 LHS = getSCEV(BO->LHS);
8114 RHS = getSCEV(BO->RHS);
8115 return getURemExpr(LHS, RHS);
8116 case Instruction::Sub: {
8118 if (BO->Op)
8119 Flags = getNoWrapFlagsFromUB(BO->Op);
8120 LHS = getSCEV(BO->LHS);
8121 RHS = getSCEV(BO->RHS);
8122 return getMinusSCEV(LHS, RHS, Flags);
8123 }
8124 case Instruction::And:
8125 // For an expression like x&255 that merely masks off the high bits,
8126 // use zext(trunc(x)) as the SCEV expression.
8127 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8128 if (CI->isZero())
8129 return getSCEV(BO->RHS);
8130 if (CI->isMinusOne())
8131 return getSCEV(BO->LHS);
8132 const APInt &A = CI->getValue();
8133
8134 // Instcombine's ShrinkDemandedConstant may strip bits out of
8135 // constants, obscuring what would otherwise be a low-bits mask.
8136 // Use computeKnownBits to compute what ShrinkDemandedConstant
8137 // knew about to reconstruct a low-bits mask value.
8138 unsigned LZ = A.countl_zero();
8139 unsigned TZ = A.countr_zero();
8140 unsigned BitWidth = A.getBitWidth();
8141 KnownBits Known(BitWidth);
8142 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8143
8144 APInt EffectiveMask =
8145 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8146 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8147 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8148 const SCEV *LHS = getSCEV(BO->LHS);
8149 const SCEV *ShiftedLHS = nullptr;
8150 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8151 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8152 // For an expression like (x * 8) & 8, simplify the multiply.
8153 unsigned MulZeros = OpC->getAPInt().countr_zero();
8154 unsigned GCD = std::min(MulZeros, TZ);
8155 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8157 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8158 append_range(MulOps, LHSMul->operands().drop_front());
8159 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8160 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8161 }
8162 }
8163 if (!ShiftedLHS)
8164 ShiftedLHS = getUDivExpr(LHS, MulCount);
8165 return getMulExpr(
8167 getTruncateExpr(ShiftedLHS,
8168 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8169 BO->LHS->getType()),
8170 MulCount);
8171 }
8172 }
8173 // Binary `and` is a bit-wise `umin`.
8174 if (BO->LHS->getType()->isIntegerTy(1)) {
8175 LHS = getSCEV(BO->LHS);
8176 RHS = getSCEV(BO->RHS);
8177 return getUMinExpr(LHS, RHS);
8178 }
8179 break;
8180
8181 case Instruction::Or:
8182 // Binary `or` is a bit-wise `umax`.
8183 if (BO->LHS->getType()->isIntegerTy(1)) {
8184 LHS = getSCEV(BO->LHS);
8185 RHS = getSCEV(BO->RHS);
8186 return getUMaxExpr(LHS, RHS);
8187 }
8188 break;
8189
8190 case Instruction::Xor:
8191 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8192 // If the RHS of xor is -1, then this is a not operation.
8193 if (CI->isMinusOne())
8194 return getNotSCEV(getSCEV(BO->LHS));
8195
8196 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8197 // This is a variant of the check for xor with -1, and it handles
8198 // the case where instcombine has trimmed non-demanded bits out
8199 // of an xor with -1.
8200 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8201 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8202 if (LBO->getOpcode() == Instruction::And &&
8203 LCI->getValue() == CI->getValue())
8204 if (const SCEVZeroExtendExpr *Z =
8206 Type *UTy = BO->LHS->getType();
8207 const SCEV *Z0 = Z->getOperand();
8208 Type *Z0Ty = Z0->getType();
8209 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8210
8211 // If C is a low-bits mask, the zero extend is serving to
8212 // mask off the high bits. Complement the operand and
8213 // re-apply the zext.
8214 if (CI->getValue().isMask(Z0TySize))
8215 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8216
8217 // If C is a single bit, it may be in the sign-bit position
8218 // before the zero-extend. In this case, represent the xor
8219 // using an add, which is equivalent, and re-apply the zext.
8220 APInt Trunc = CI->getValue().trunc(Z0TySize);
8221 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8222 Trunc.isSignMask())
8223 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8224 UTy);
8225 }
8226 }
8227 break;
8228
8229 case Instruction::Shl:
8230 // Turn shift left of a constant amount into a multiply.
8231 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8232 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8233
8234 // If the shift count is not less than the bitwidth, the result of
8235 // the shift is undefined. Don't try to analyze it, because the
8236 // resolution chosen here may differ from the resolution chosen in
8237 // other parts of the compiler.
8238 if (SA->getValue().uge(BitWidth))
8239 break;
8240
8241 // We can safely preserve the nuw flag in all cases. It's also safe to
8242 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8243 // requires special handling. It can be preserved as long as we're not
8244 // left shifting by bitwidth - 1.
8245 auto Flags = SCEV::FlagAnyWrap;
8246 if (BO->Op) {
8247 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8248 if (any(MulFlags & SCEV::FlagNSW) &&
8249 (any(MulFlags & SCEV::FlagNUW) ||
8250 SA->getValue().ult(BitWidth - 1)))
8252 if (any(MulFlags & SCEV::FlagNUW))
8254 }
8255
8256 ConstantInt *X = ConstantInt::get(
8257 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8258 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8259 }
8260 break;
8261
8262 case Instruction::AShr:
8263 // AShr X, C, where C is a constant.
8264 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8265 if (!CI)
8266 break;
8267
8268 Type *OuterTy = BO->LHS->getType();
8269 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8270 // If the shift count is not less than the bitwidth, the result of
8271 // the shift is undefined. Don't try to analyze it, because the
8272 // resolution chosen here may differ from the resolution chosen in
8273 // other parts of the compiler.
8274 if (CI->getValue().uge(BitWidth))
8275 break;
8276
8277 if (CI->isZero())
8278 return getSCEV(BO->LHS); // shift by zero --> noop
8279
8280 uint64_t AShrAmt = CI->getZExtValue();
8281 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8282
8283 Operator *L = dyn_cast<Operator>(BO->LHS);
8284 const SCEV *AddTruncateExpr = nullptr;
8285 ConstantInt *ShlAmtCI = nullptr;
8286 const SCEV *AddConstant = nullptr;
8287
8288 if (L && L->getOpcode() == Instruction::Add) {
8289 // X = Shl A, n
8290 // Y = Add X, c
8291 // Z = AShr Y, m
8292 // n, c and m are constants.
8293
8294 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8295 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8296 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8297 if (AddOperandCI) {
8298 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8299 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8300 // since we truncate to TruncTy, the AddConstant should be of the
8301 // same type, so create a new Constant with type same as TruncTy.
8302 // Also, the Add constant should be shifted right by AShr amount.
8303 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8304 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8305 // we model the expression as sext(add(trunc(A), c << n)), since the
8306 // sext(trunc) part is already handled below, we create a
8307 // AddExpr(TruncExp) which will be used later.
8308 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8309 }
8310 }
8311 } else if (L && L->getOpcode() == Instruction::Shl) {
8312 // X = Shl A, n
8313 // Y = AShr X, m
8314 // Both n and m are constant.
8315
8316 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8317 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8318 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8319 }
8320
8321 if (AddTruncateExpr && ShlAmtCI) {
8322 // We can merge the two given cases into a single SCEV statement,
8323 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8324 // a simpler case. The following code handles the two cases:
8325 //
8326 // 1) For a two-shift sext-inreg, i.e. n = m,
8327 // use sext(trunc(x)) as the SCEV expression.
8328 //
8329 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
8330 // expression. We check below that ShlAmt < BitWidth, so
8331 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8332 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
8333 const APInt &ShlAmt = ShlAmtCI->getValue();
8334 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8335 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8336 ShlAmtCI->getZExtValue() - AShrAmt);
8337 const SCEV *CompositeExpr =
8338 getMulExpr(AddTruncateExpr, getConstant(Mul));
8339 if (L->getOpcode() != Instruction::Shl)
8340 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8341
8342 return getSignExtendExpr(CompositeExpr, OuterTy);
8343 }
8344 }
8345 break;
8346 }
8347 }
8348
8349 switch (U->getOpcode()) {
8350 case Instruction::Trunc:
8351 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8352
8353 case Instruction::ZExt:
8354 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8355
8356 case Instruction::SExt:
8357 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8359 // The NSW flag of a subtract does not always survive the conversion to
8360 // A + (-1)*B. By pushing sign extension onto its operands we are much
8361 // more likely to preserve NSW and allow later AddRec optimisations.
8362 //
8363 // NOTE: This is effectively duplicating this logic from getSignExtend:
8364 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8365 // but by that point the NSW information has potentially been lost.
8366 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8367 Type *Ty = U->getType();
8368 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8369 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8370 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8371 }
8372 }
8373 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8374
8375 case Instruction::BitCast:
8376 // BitCasts are no-op casts so we just eliminate the cast.
8377 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8378 return getSCEV(U->getOperand(0));
8379 break;
8380
8381 case Instruction::PtrToAddr: {
8382 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8383 if (isa<SCEVCouldNotCompute>(IntOp))
8384 return getUnknown(V);
8385 return IntOp;
8386 }
8387
8388 case Instruction::PtrToInt: {
8389 // Pointer to integer cast is straight-forward, so do model it.
8390 const SCEV *Op = getSCEV(U->getOperand(0));
8391 Type *DstIntTy = U->getType();
8392 // But only if effective SCEV (integer) type is wide enough to represent
8393 // all possible pointer values.
8394 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8395 if (isa<SCEVCouldNotCompute>(IntOp))
8396 return getUnknown(V);
8397 return IntOp;
8398 }
8399 case Instruction::IntToPtr:
8400 // Just don't deal with inttoptr casts.
8401 return getUnknown(V);
8402
8403 case Instruction::SDiv:
8404 // If both operands are non-negative, this is just an udiv.
8405 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8406 isKnownNonNegative(getSCEV(U->getOperand(1))))
8407 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8408 break;
8409
8410 case Instruction::SRem:
8411 // If both operands are non-negative, this is just an urem.
8412 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8413 isKnownNonNegative(getSCEV(U->getOperand(1))))
8414 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8415 break;
8416
8417 case Instruction::GetElementPtr:
8418 return createNodeForGEP(cast<GEPOperator>(U));
8419
8420 case Instruction::PHI:
8421 return createNodeForPHI(cast<PHINode>(U));
8422
8423 case Instruction::Select:
8424 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8425 U->getOperand(2));
8426
8427 case Instruction::Call:
8428 case Instruction::Invoke:
8429 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8430 return getSCEV(RV);
8431
8432 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8433 switch (II->getIntrinsicID()) {
8434 case Intrinsic::abs:
8435 return getAbsExpr(
8436 getSCEV(II->getArgOperand(0)),
8437 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8438 case Intrinsic::umax:
8439 LHS = getSCEV(II->getArgOperand(0));
8440 RHS = getSCEV(II->getArgOperand(1));
8441 return getUMaxExpr(LHS, RHS);
8442 case Intrinsic::umin:
8443 LHS = getSCEV(II->getArgOperand(0));
8444 RHS = getSCEV(II->getArgOperand(1));
8445 return getUMinExpr(LHS, RHS);
8446 case Intrinsic::smax:
8447 LHS = getSCEV(II->getArgOperand(0));
8448 RHS = getSCEV(II->getArgOperand(1));
8449 return getSMaxExpr(LHS, RHS);
8450 case Intrinsic::smin:
8451 LHS = getSCEV(II->getArgOperand(0));
8452 RHS = getSCEV(II->getArgOperand(1));
8453 return getSMinExpr(LHS, RHS);
8454 case Intrinsic::usub_sat: {
8455 const SCEV *X = getSCEV(II->getArgOperand(0));
8456 const SCEV *Y = getSCEV(II->getArgOperand(1));
8457 const SCEV *ClampedY = getUMinExpr(X, Y);
8458 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8459 }
8460 case Intrinsic::uadd_sat: {
8461 const SCEV *X = getSCEV(II->getArgOperand(0));
8462 const SCEV *Y = getSCEV(II->getArgOperand(1));
8463 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8464 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8465 }
8466 case Intrinsic::start_loop_iterations:
8467 case Intrinsic::annotation:
8468 case Intrinsic::ptr_annotation:
8469 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8470 // just equivalent to the first operand for SCEV purposes.
8471 return getSCEV(II->getArgOperand(0));
8472 case Intrinsic::vscale:
8473 return getVScale(II->getType());
8474 default:
8475 break;
8476 }
8477 }
8478 break;
8479 }
8480
8481 return getUnknown(V);
8482}
8483
8484//===----------------------------------------------------------------------===//
8485// Iteration Count Computation Code
8486//
8487
8489 if (isa<SCEVCouldNotCompute>(ExitCount))
8490 return getCouldNotCompute();
8491
8492 auto *ExitCountType = ExitCount->getType();
8493 assert(ExitCountType->isIntegerTy());
8494 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8495 1 + ExitCountType->getScalarSizeInBits());
8496 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8497}
8498
8500 Type *EvalTy,
8501 const Loop *L) {
8502 if (isa<SCEVCouldNotCompute>(ExitCount))
8503 return getCouldNotCompute();
8504
8505 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8506 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8507
8508 auto CanAddOneWithoutOverflow = [&]() {
8509 ConstantRange ExitCountRange =
8510 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8511 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8512 return true;
8513
8514 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8515 getMinusOne(ExitCount->getType()));
8516 };
8517
8518 // If we need to zero extend the backedge count, check if we can add one to
8519 // it prior to zero extending without overflow. Provided this is safe, it
8520 // allows better simplification of the +1.
8521 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8522 return getZeroExtendExpr(
8523 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8524
8525 // Get the total trip count from the count by adding 1. This may wrap.
8526 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8527}
8528
8529static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8530 if (!ExitCount)
8531 return 0;
8532
8533 ConstantInt *ExitConst = ExitCount->getValue();
8534
8535 // Guard against huge trip counts.
8536 if (ExitConst->getValue().getActiveBits() > 32)
8537 return 0;
8538
8539 // In case of integer overflow, this returns 0, which is correct.
8540 return ((unsigned)ExitConst->getZExtValue()) + 1;
8541}
8542
8544 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8545 return getConstantTripCount(ExitCount);
8546}
8547
8548unsigned
8550 const BasicBlock *ExitingBlock) {
8551 assert(ExitingBlock && "Must pass a non-null exiting block!");
8552 assert(L->isLoopExiting(ExitingBlock) &&
8553 "Exiting block must actually branch out of the loop!");
8554 const SCEVConstant *ExitCount =
8555 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8556 return getConstantTripCount(ExitCount);
8557}
8558
8560 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8561
8562 const auto *MaxExitCount =
8563 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8565 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8566}
8567
8569 SmallVector<BasicBlock *, 8> ExitingBlocks;
8570 L->getExitingBlocks(ExitingBlocks);
8571
8572 std::optional<unsigned> Res;
8573 for (auto *ExitingBB : ExitingBlocks) {
8574 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8575 if (!Res)
8576 Res = Multiple;
8577 Res = std::gcd(*Res, Multiple);
8578 }
8579 return Res.value_or(1);
8580}
8581
8583 const SCEV *ExitCount) {
8584 if (isa<SCEVCouldNotCompute>(ExitCount))
8585 return 1;
8586
8587 // Get the trip count
8588 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8589
8590 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8591 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8592 // the greatest power of 2 divisor less than 2^32.
8593 return Multiple.getActiveBits() > 32
8594 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8595 : (unsigned)Multiple.getZExtValue();
8596}
8597
8598/// Returns the largest constant divisor of the trip count of this loop as a
8599/// normal unsigned value, if possible. This means that the actual trip count is
8600/// always a multiple of the returned value (don't forget the trip count could
8601/// very well be zero as well!).
8602///
8603/// Returns 1 if the trip count is unknown or not guaranteed to be a
8604/// multiple of a constant (which is also the case if the trip count is simply
8605/// constant; use getSmallConstantTripCount for that case). If the multiple is
8606/// huge (>= 2^32), returns its greatest power-of-2 divisor below 2^32.
8607///
8608/// As explained in the comments for getSmallConstantTripCount, this assumes
8609/// that control exits the loop via ExitingBlock.
8610unsigned
8612 const BasicBlock *ExitingBlock) {
8613 assert(ExitingBlock && "Must pass a non-null exiting block!");
8614 assert(L->isLoopExiting(ExitingBlock) &&
8615 "Exiting block must actually branch out of the loop!");
8616 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8617 return getSmallConstantTripMultiple(L, ExitCount);
8618}
8619
8621 const BasicBlock *ExitingBlock,
8622 ExitCountKind Kind) {
8623 switch (Kind) {
8624 case Exact:
8625 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8626 case SymbolicMaximum:
8627 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8628 case ConstantMaximum:
8629 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8630 };
8631 llvm_unreachable("Invalid ExitCountKind!");
8632}
8633
8635 const Loop *L, const BasicBlock *ExitingBlock,
8637 switch (Kind) {
8638 case Exact:
8639 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8640 Predicates);
8641 case SymbolicMaximum:
8642 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8643 Predicates);
8644 case ConstantMaximum:
8645 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8646 Predicates);
8647 };
8648 llvm_unreachable("Invalid ExitCountKind!");
8649}
8650
8653 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8654}
8655
8657 ExitCountKind Kind) {
8658 switch (Kind) {
8659 case Exact:
8660 return getBackedgeTakenInfo(L).getExact(L, this);
8661 case ConstantMaximum:
8662 return getBackedgeTakenInfo(L).getConstantMax(this);
8663 case SymbolicMaximum:
8664 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8665 };
8666 llvm_unreachable("Invalid ExitCountKind!");
8667}
8668
8671 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8672}
8673
8676 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8677}
8678
8680 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8681}
8682
8683ScalarEvolution::BackedgeTakenInfo &
8684ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8685 auto &BTI = getBackedgeTakenInfo(L);
8686 if (BTI.hasFullInfo())
8687 return BTI;
8688
8689 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8690
8691 if (!Pair.second)
8692 return Pair.first->second;
8693
8694 BackedgeTakenInfo Result =
8695 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8696
8697 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8698}
8699
8700ScalarEvolution::BackedgeTakenInfo &
8701ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8702 // Initially insert an invalid entry for this loop. If the insertion
8703 // succeeds, proceed to actually compute a backedge-taken count and
8704 // update the value. The temporary CouldNotCompute value tells SCEV
8705 // code elsewhere that it shouldn't attempt to request a new
8706 // backedge-taken count, which could result in infinite recursion.
8707 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8708 BackedgeTakenCounts.try_emplace(L);
8709 if (!Pair.second)
8710 return Pair.first->second;
8711
8712 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8713 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8714 // must be cleared in this scope.
8715 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8716
8717 // Now that we know more about the trip count for this loop, forget any
8718 // existing SCEV values for PHI nodes in this loop since they are only
8719 // conservative estimates made without the benefit of trip count
8720 // information. This invalidation is not necessary for correctness, and is
8721 // only done to produce more precise results.
8722 if (Result.hasAnyInfo()) {
8723 // Invalidate any expression using an addrec in this loop.
8724 SmallVector<SCEVUse, 8> ToForget;
8725 auto LoopUsersIt = LoopUsers.find(L);
8726 if (LoopUsersIt != LoopUsers.end())
8727 append_range(ToForget, LoopUsersIt->second);
8728 forgetMemoizedResults(ToForget);
8729
8730 // Invalidate constant-evolved loop header phis.
8731 for (PHINode &PN : L->getHeader()->phis())
8732 ConstantEvolutionLoopExitValue.erase(&PN);
8733 }
8734
8735 // Re-lookup the insert position, since the call to
8736 // computeBackedgeTakenCount above could result in a
8737 // recusive call to getBackedgeTakenInfo (on a different
8738 // loop), which would invalidate the iterator computed
8739 // earlier.
8740 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8741}
8742
8744 // This method is intended to forget all info about loops. It should
8745 // invalidate caches as if the following happened:
8746 // - The trip counts of all loops have changed arbitrarily
8747 // - Every llvm::Value has been updated in place to produce a different
8748 // result.
8749 BackedgeTakenCounts.clear();
8750 PredicatedBackedgeTakenCounts.clear();
8751 BECountUsers.clear();
8752 LoopPropertiesCache.clear();
8753 ConstantEvolutionLoopExitValue.clear();
8754 ValueExprMap.clear();
8755 ValuesAtScopes.clear();
8756 ValuesAtScopesUsers.clear();
8757 LoopDispositions.clear();
8758 BlockDispositions.clear();
8759 UnsignedRanges.clear();
8760 SignedRanges.clear();
8761 ExprValueMap.clear();
8762 HasRecMap.clear();
8763 ConstantMultipleCache.clear();
8764 PredicatedSCEVRewrites.clear();
8765 FoldCache.clear();
8766 FoldCacheUser.clear();
8767}
8768void ScalarEvolution::visitAndClearUsers(
8771 SmallVectorImpl<SCEVUse> &ToForget) {
8772 while (!Worklist.empty()) {
8773 Instruction *I = Worklist.pop_back_val();
8774 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8775 continue;
8776
8778 ValueExprMap.find_as(static_cast<Value *>(I));
8779 if (It != ValueExprMap.end()) {
8780 ToForget.push_back(It->second);
8781 eraseValueFromMap(It->first);
8782 if (PHINode *PN = dyn_cast<PHINode>(I))
8783 ConstantEvolutionLoopExitValue.erase(PN);
8784 }
8785
8786 PushDefUseChildren(I, Worklist, Visited);
8787 }
8788}
8789
8791 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8792 SmallVector<SCEVUse, 16> ToForget;
8793
8794 // Iterate over all the loops and sub-loops to drop SCEV information.
8795 while (!LoopWorklist.empty()) {
8796 auto *CurrL = LoopWorklist.pop_back_val();
8797
8798 // Drop any stored trip count value.
8799 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8800 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8801
8802 // Drop information about predicated SCEV rewrites for this loop.
8803 PredicatedSCEVRewrites.remove_if(
8804 [&](const auto &Entry) { return Entry.first.second == CurrL; });
8805
8806 auto LoopUsersItr = LoopUsers.find(CurrL);
8807 if (LoopUsersItr != LoopUsers.end())
8808 llvm::append_range(ToForget, LoopUsersItr->second);
8809
8810 // Drop information about expressions based on loop-header PHIs.
8811 for (PHINode &PN : CurrL->getHeader()->phis()) {
8812 ConstantEvolutionLoopExitValue.erase(&PN);
8813 auto VIt = ValueExprMap.find_as(static_cast<Value *>(&PN));
8814 if (VIt != ValueExprMap.end())
8815 ToForget.push_back(VIt->second);
8816 }
8817
8818 LoopPropertiesCache.erase(CurrL);
8819 // Forget all contained loops too, to avoid dangling entries in the
8820 // ValuesAtScopes map.
8821 LoopWorklist.append(CurrL->begin(), CurrL->end());
8822 }
8823 forgetMemoizedResults(ToForget);
8824}
8825
8827 forgetLoop(L->getOutermostLoop());
8828}
8829
8832 if (!I) return;
8833
8834 // Drop information about expressions based on loop-header PHIs.
8837 SmallVector<SCEVUse, 8> ToForget;
8838 Worklist.push_back(I);
8839 Visited.insert(I);
8840 visitAndClearUsers(Worklist, Visited, ToForget);
8841
8842 forgetMemoizedResults(ToForget);
8843}
8844
8846 if (!isSCEVable(V->getType()))
8847 return;
8848
8849 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8850 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8851 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8852 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8853 if (const SCEV *S = getExistingSCEV(V)) {
8854 struct InvalidationRootCollector {
8855 Loop *L;
8857
8858 InvalidationRootCollector(Loop *L) : L(L) {}
8859
8860 bool follow(const SCEV *S) {
8861 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8862 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8863 if (L->contains(I))
8864 Roots.push_back(S);
8865 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8866 if (L->contains(AddRec->getLoop()))
8867 Roots.push_back(S);
8868 }
8869 return true;
8870 }
8871 bool isDone() const { return false; }
8872 };
8873
8874 InvalidationRootCollector C(L);
8875 visitAll(S, C);
8876 forgetMemoizedResults(C.Roots);
8877 }
8878
8879 // Also perform the normal invalidation.
8880 forgetValue(V);
8881}
8882
8883void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8884
8886 // Unless a specific value is passed to invalidation, completely clear both
8887 // caches.
8888 if (!V) {
8889 BlockDispositions.clear();
8890 LoopDispositions.clear();
8891 return;
8892 }
8893
8894 if (!isSCEVable(V->getType()))
8895 return;
8896
8897 const SCEV *S = getExistingSCEV(V);
8898 if (!S)
8899 return;
8900
8901 // Invalidate the block and loop dispositions cached for S. Dispositions of
8902 // S's users may change if S's disposition changes (i.e. a user may change to
8903 // loop-invariant, if S changes to loop invariant), so also invalidate
8904 // dispositions of S's users recursively.
8905 SmallVector<SCEVUse, 8> Worklist = {S};
8907 while (!Worklist.empty()) {
8908 const SCEV *Curr = Worklist.pop_back_val();
8909 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8910 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8911 if (!LoopDispoRemoved && !BlockDispoRemoved)
8912 continue;
8913 auto Users = SCEVUsers.find(Curr);
8914 if (Users != SCEVUsers.end())
8915 for (const auto *User : Users->second)
8916 if (Seen.insert(User).second)
8917 Worklist.push_back(User);
8918 }
8919}
8920
8921/// Get the exact loop backedge taken count considering all loop exits. A
8922/// computable result can only be returned for loops with all exiting blocks
8923/// dominating the latch. howFarToZero assumes that the limit of each loop test
8924/// is never skipped. This is a valid assumption as long as the loop exits via
8925/// that test. For precise results, it is the caller's responsibility to specify
8926/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8927const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8928 const Loop *L, ScalarEvolution *SE,
8930 // If any exits were not computable, the loop is not computable.
8931 if (!isComplete() || ExitNotTaken.empty())
8932 return SE->getCouldNotCompute();
8933
8934 const BasicBlock *Latch = L->getLoopLatch();
8935 // All exiting blocks we have collected must dominate the only backedge.
8936 if (!Latch)
8937 return SE->getCouldNotCompute();
8938
8939 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8940 // count is simply a minimum out of all these calculated exit counts.
8942 for (const auto &ENT : ExitNotTaken) {
8943 const SCEV *BECount = ENT.ExactNotTaken;
8944 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8945 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8946 "We should only have known counts for exiting blocks that dominate "
8947 "latch!");
8948
8949 Ops.push_back(BECount);
8950
8951 if (Preds)
8952 append_range(*Preds, ENT.Predicates);
8953
8954 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8955 "Predicate should be always true!");
8956 }
8957
8958 // If an earlier exit exits on the first iteration (exit count zero), then
8959 // a later poison exit count should not propagate into the result. These are
8960 // exactly the semantics provided by umin_seq.
8961 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8962}
8963
8964const ScalarEvolution::ExitNotTakenInfo *
8965ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8966 const BasicBlock *ExitingBlock,
8967 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8968 for (const auto &ENT : ExitNotTaken)
8969 if (ENT.ExitingBlock == ExitingBlock) {
8970 if (ENT.hasAlwaysTruePredicate())
8971 return &ENT;
8972 else if (Predicates) {
8973 append_range(*Predicates, ENT.Predicates);
8974 return &ENT;
8975 }
8976 }
8977
8978 return nullptr;
8979}
8980
8981/// getConstantMax - Get the constant max backedge taken count for the loop.
8982const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8983 ScalarEvolution *SE,
8984 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8985 if (!getConstantMax())
8986 return SE->getCouldNotCompute();
8987
8988 for (const auto &ENT : ExitNotTaken)
8989 if (!ENT.hasAlwaysTruePredicate()) {
8990 if (!Predicates)
8991 return SE->getCouldNotCompute();
8992 append_range(*Predicates, ENT.Predicates);
8993 }
8994
8995 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8996 isa<SCEVConstant>(getConstantMax())) &&
8997 "No point in having a non-constant max backedge taken count!");
8998 return getConstantMax();
8999}
9000
9001const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
9002 const Loop *L, ScalarEvolution *SE,
9003 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
9004 if (!SymbolicMax) {
9005 // Form an expression for the maximum exit count possible for this loop. We
9006 // merge the max and exact information to approximate a version of
9007 // getConstantMaxBackedgeTakenCount which isn't restricted to just
9008 // constants.
9009 SmallVector<SCEVUse, 4> ExitCounts;
9010
9011 for (const auto &ENT : ExitNotTaken) {
9012 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
9013 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
9014 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
9015 "We should only have known counts for exiting blocks that "
9016 "dominate latch!");
9017 ExitCounts.push_back(ExitCount);
9018 if (Predicates)
9019 append_range(*Predicates, ENT.Predicates);
9020
9021 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9022 "Predicate should be always true!");
9023 }
9024 }
9025 if (ExitCounts.empty())
9026 SymbolicMax = SE->getCouldNotCompute();
9027 else
9028 SymbolicMax =
9029 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9030 }
9031 return SymbolicMax;
9032}
9033
9034bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9035 ScalarEvolution *SE) const {
9036 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9037 return !ENT.hasAlwaysTruePredicate();
9038 };
9039 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9040}
9041
9044
9046 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9047 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9051 // If we prove the max count is zero, so is the symbolic bound. This happens
9052 // in practice due to differences in a) how context sensitive we've chosen
9053 // to be and b) how we reason about bounds implied by UB.
9054 if (ConstantMaxNotTaken->isZero()) {
9055 this->ExactNotTaken = E = ConstantMaxNotTaken;
9056 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9057 }
// The following asserts check the relative precision of the exact,
// symbolic-max and constant-max counts.
9058
9061 "Exact is not allowed to be less precise than Constant Max");
9064 "Exact is not allowed to be less precise than Symbolic Max");
9067 "Symbolic Max is not allowed to be less precise than Constant Max");
9070 "No point in having a non-constant max backedge taken count!");
// Flatten the supplied predicate lists into Predicates. Duplicates (tracked
// in SeenPreds) are dropped and first-seen order is kept. Only leaf predicates
// are expected here, never SCEVUnionPredicate.
9072 for (const auto PredList : PredLists)
9073 for (const auto *P : PredList) {
9074 if (SeenPreds.contains(P))
9075 continue;
9076 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9077 SeenPreds.insert(P);
9078 Predicates.push_back(P);
9079 }
// Exit counts are integer-valued; pointer-typed counts are rejected.
9080 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9081 "Backedge count should be int");
9083 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9084 "Max backedge count should be int");
9085}
9086
9094
9095/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9096/// computable exit into a persistent ExitNotTakenInfo array.
///
/// Each entry of ExitCounts is an (exiting block, ExitLimit) pair. The
/// resulting ExitNotTakenInfo keeps the block together with the limit's
/// exact, constant-max and symbolic-max counts and its predicates.
/// \p IsComplete records whether every analyzed exit had a computable exact
/// count. \p ConstantMax must be CouldNotCompute or a SCEVConstant (asserted
/// below). \p MaxOrZero is stored unchanged.
9097ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9099 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9100 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9101 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9102
// One ExitNotTakenInfo per recorded exit, in the same order as ExitCounts.
9103 ExitNotTaken.reserve(ExitCounts.size());
9104 std::transform(ExitCounts.begin(), ExitCounts.end(),
9105 std::back_inserter(ExitNotTaken),
9106 [&](const EdgeExitInfo &EEI) {
9107 BasicBlock *ExitBB = EEI.first;
9108 const ExitLimit &EL = EEI.second;
9109 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9110 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9111 EL.Predicates);
9112 });
9113 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9114 isa<SCEVConstant>(ConstantMax)) &&
9115 "No point in having a non-constant max backedge taken count!");
9116}
9117
9118/// Compute the number of times the backedge of the specified loop will execute.
///
/// Exits whose branch condition is a constant that never leaves the loop are
/// skipped. The result records:
/// - one entry per exit with a known symbolic-max count;
/// - whether every analyzed exit had an exact count;
/// - the constant max backedge-taken count;
/// - the MaxOrZero flag.
///
/// \p AllowPredicates is forwarded to computeExitLimit; without it no exit
/// limit may carry predicates (asserted). Non-constant exact and
/// symbolic-max counts are recorded in BECountUsers for invalidation.
9119ScalarEvolution::BackedgeTakenInfo
9120ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9121 bool AllowPredicates) {
9122 SmallVector<BasicBlock *, 8> ExitingBlocks;
9123 L->getExitingBlocks(ExitingBlocks);
9124
9125 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9126
9128 bool CouldComputeBECount = true;
9129 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9130 const SCEV *MustExitMaxBECount = nullptr;
9131 const SCEV *MayExitMaxBECount = nullptr;
9132 bool MustExitMaxOrZero = false;
9133 bool IsOnlyExit = ExitingBlocks.size() == 1;
9134
9135 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9136 // and compute maxBECount.
9137 // Do a union of all the predicates here.
9138 for (BasicBlock *ExitBB : ExitingBlocks) {
9139 // We canonicalize untaken exits to br (constant), ignore them so that
9140 // proving an exit untaken doesn't negatively impact our ability to reason
9141 // about the loop as whole.
9142 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9143 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9144 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9145 if (ExitIfTrue == CI->isZero())
9146 continue;
9147 }
9148
9149 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9150
9151 assert((AllowPredicates || EL.Predicates.empty()) &&
9152 "Predicated exit limit when predicates are not allowed!");
9153
9154 // 1. For each exit that can be computed, add an entry to ExitCounts.
9155 // CouldComputeBECount is true only if all exits can be computed.
9156 if (EL.ExactNotTaken != getCouldNotCompute())
9157 ++NumExitCountsComputed;
9158 else
9159 // We couldn't compute an exact value for this exit, so
9160 // we won't be able to compute an exact value for the loop.
9161 CouldComputeBECount = false;
9162 // Remember exit count if either exact or symbolic is known. Because
9163 // Exact always implies symbolic, only check symbolic.
9164 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9165 ExitCounts.emplace_back(ExitBB, EL);
9166 else {
9167 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9168 "Exact is known but symbolic isn't?");
9169 ++NumExitCountsNotComputed;
9170 }
9171
9172 // 2. Derive the loop's MaxBECount from each exit's max number of
9173 // non-exiting iterations. Partition the loop exits into two kinds:
9174 // LoopMustExits and LoopMayExits.
9175 //
9176 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9177 // is a LoopMayExit. If any computable LoopMustExit is found, then
9178 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9179 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9180 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9181 // any
9182 // computable EL.ConstantMaxNotTaken.
9183 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9184 DT.dominates(ExitBB, Latch)) {
9185 if (!MustExitMaxBECount) {
9186 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9187 MustExitMaxOrZero = EL.MaxOrZero;
9188 } else {
9189 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9190 EL.ConstantMaxNotTaken);
9191 }
9192 } else if (MayExitMaxBECount != getCouldNotCompute()) {
// Reached for non-dominating exits and for dominating exits without a
// constant max. Once MayExitMaxBECount becomes CouldNotCompute it stays
// there, since an unknown bound is treated as larger than any known one.
9193 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9194 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9195 else {
9196 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9197 EL.ConstantMaxNotTaken);
9198 }
9199 }
9200 }
// A computable must-exit bound wins; otherwise fall back to the may-exit
// bound, if any was seen.
9201 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9202 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9203 // The loop backedge will be taken the maximum or zero times if there's
9204 // a single exit that must be taken the maximum or zero times.
9205 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9206
9207 // Remember which SCEVs are used in exit limits for invalidation purposes.
9208 // We only care about non-constant SCEVs here, so we can ignore
9209 // EL.ConstantMaxNotTaken
9210 // and MaxBECount, which must be SCEVConstant.
9211 for (const auto &Pair : ExitCounts) {
9212 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9213 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9214 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9215 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9216 {L, AllowPredicates});
9217 }
9218 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9219 MaxBECount, MaxOrZero);
9220}
9221
9222ScalarEvolution::ExitLimit
9223ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9224 bool IsOnlyExit, bool AllowPredicates) {
9225 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9226 // If our exiting block does not dominate the latch, then its connection with
9227 // loop's exit limit may be far from trivial.
9228 const BasicBlock *Latch = L->getLoopLatch();
9229 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9230 return getCouldNotCompute();
9231
9232 Instruction *Term = ExitingBlock->getTerminator();
9233 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9234 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9235 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9236 "It should have one successor in loop and one exit block!");
9237 // Proceed to the next level to examine the exit condition expression.
9238 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9239 /*ControlsOnlyExit=*/IsOnlyExit,
9240 AllowPredicates);
9241 }
9242
9243 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9244 // For switch, make sure that there is a single exit from the loop.
9245 BasicBlock *Exit = nullptr;
9246 for (auto *SBB : successors(ExitingBlock))
9247 if (!L->contains(SBB)) {
9248 if (Exit) // Multiple exit successors.
9249 return getCouldNotCompute();
9250 Exit = SBB;
9251 }
9252 assert(Exit && "Exiting block must have at least one exit");
9253 return computeExitLimitFromSingleExitSwitch(
9254 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9255 }
9256
9257 return getCouldNotCompute();
9258}
9259
9261 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9262 bool AllowPredicates) {
// A fresh cache is created for each top-level query. It is tied to the
// (L, ExitIfTrue, AllowPredicates) triple and shared by the recursive
// evaluation of and/or sub-conditions.
9263 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9264 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9265 ControlsOnlyExit, AllowPredicates);
9266}
9267
9268std::optional<ScalarEvolution::ExitLimit>
9269ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9270 bool ExitIfTrue, bool ControlsOnlyExit,
9271 bool AllowPredicates) {
9272 (void)this->L;
9273 (void)this->ExitIfTrue;
9274 (void)this->AllowPredicates;
9275
9276 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9277 this->AllowPredicates == AllowPredicates &&
9278 "Variance in assumed invariant key components!");
9279 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9280 if (Itr == TripCountMap.end())
9281 return std::nullopt;
9282 return Itr->second;
9283}
9284
9285void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9286 bool ExitIfTrue,
9287 bool ControlsOnlyExit,
9288 bool AllowPredicates,
9289 const ExitLimit &EL) {
9290 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9291 this->AllowPredicates == AllowPredicates &&
9292 "Variance in assumed invariant key components!");
9293
9294 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9295 assert(InsertResult.second && "Expected successful insertion!");
9296 (void)InsertResult;
9297 (void)ExitIfTrue;
9298}
9299
9300ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9301 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9302 bool ControlsOnlyExit, bool AllowPredicates) {
9303
9304 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9305 AllowPredicates))
9306 return *MaybeEL;
9307
9308 ExitLimit EL = computeExitLimitFromCondImpl(
9309 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9310 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9311 return EL;
9312}
9313
/// Compute the exit limit for the branch condition \p ExitCond. Strategies
/// are tried in order:
/// 1. logical and/or decomposition;
/// 2. icmp analysis, first without passing AllowPredicates and then retried
///    with AllowPredicates=true if that lacks full information and predicates
///    are allowed;
/// 3. constant conditions;
/// 4. the overflow bit of an x.with.overflow intrinsic with a constant RHS;
/// 5. exhaustive evaluation via computeExitCountExhaustively.
9314ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9315 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9316 bool ControlsOnlyExit, bool AllowPredicates) {
9317 // Handle BinOp conditions (And, Or).
9318 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9319 Cache, L, ExitCond, ExitIfTrue, AllowPredicates))
9320 return *LimitFromBinOp;
9321
9322 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9323 // Proceed to the next level to examine the icmp.
9324 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9325 ExitLimit EL =
9326 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9327 if (EL.hasFullInfo() || !AllowPredicates)
9328 return EL;
9329
9330 // Try again, but use SCEV predicates this time.
9331 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9332 ControlsOnlyExit,
9333 /*AllowPredicates=*/true);
9334 }
9335
9336 // Check for a constant condition. These are normally stripped out by
9337 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9338 // preserve the CFG and is temporarily leaving constant conditions
9339 // in place.
9340 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9341 if (ExitIfTrue == !CI->getZExtValue())
9342 // The backedge is always taken.
9343 return getCouldNotCompute();
9344 // The backedge is never taken.
9345 return getZero(CI->getType());
9346 }
9347
9348 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9349 // with a constant step, we can form an equivalent icmp predicate and figure
9350 // out how many iterations will be taken before we exit.
9351 const WithOverflowInst *WO;
9352 const APInt *C;
9353 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9354 match(WO->getRHS(), m_APInt(C))) {
9355 ConstantRange NWR =
9357 WO->getNoWrapKind());
9358 CmpInst::Predicate Pred;
9359 APInt NewRHSC, Offset;
9360 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
// Assuming NWR is the range of LHS values for which the operation does not
// overflow, Pred/NewRHSC/Offset express that range as an icmp. The callee
// wants the condition for staying in the loop, so invert it when the loop
// exits on the no-overflow outcome.
9361 if (!ExitIfTrue)
9362 Pred = ICmpInst::getInversePredicate(Pred);
9363 auto *LHS = getSCEV(WO->getLHS());
9364 if (Offset != 0)
9366 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9367 ControlsOnlyExit, AllowPredicates);
9368 if (EL.hasAnyInfo())
9369 return EL;
9370 }
9371
9372 // If it's not an integer or pointer comparison then compute it the hard way.
9373 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9374}
9375
9376std::optional<ScalarEvolution::ExitLimit>
9377ScalarEvolution::computeExitLimitFromCondFromBinOp(ExitLimitCacheTy &Cache,
9378 const Loop *L,
9379 Value *ExitCond,
9380 bool ExitIfTrue,
9381 bool AllowPredicates) {
9382 // Check if the controlling expression for this loop is an And or Or.
9383 Value *Op0, *Op1;
9384 bool IsAnd;
9385 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9386 IsAnd = true;
9387 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9388 IsAnd = false;
9389 else
9390 return std::nullopt;
9391
9392 // A sub-condition of a non-trivial binop never solely controls the exit,
9393 // whether we exit always depends on both conditions.
9394 ExitLimit EL0 = computeExitLimitFromCondCached(
9395 Cache, L, Op0, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9396 ExitLimit EL1 = computeExitLimitFromCondCached(
9397 Cache, L, Op1, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9398
9399 // EitherMayExit is true in these two cases:
9400 // br (and Op0 Op1), loop, exit
9401 // br (or Op0 Op1), exit, loop
9402 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9403
9404 const SCEV *BECount = getCouldNotCompute();
9405 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9406 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9407 if (EitherMayExit) {
9408 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9409 // Both conditions must be same for the loop to continue executing.
9410 // Choose the less conservative count.
9411 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9412 EL1.ExactNotTaken != getCouldNotCompute()) {
9413 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9414 UseSequentialUMin);
9415 }
9416 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9417 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9418 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9419 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9420 else
9421 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9422 EL1.ConstantMaxNotTaken);
9423 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9424 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9425 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9426 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9427 else
9428 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9429 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9430 } else {
9431 // Both conditions must be same at the same time for the loop to exit.
9432 // For now, be conservative.
9433 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9434 BECount = EL0.ExactNotTaken;
9435 }
9436
9437 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9438 // to be more aggressive when computing BECount than when computing
9439 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9440 // and
9441 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9442 // EL1.ConstantMaxNotTaken to not.
9443 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9444 !isa<SCEVCouldNotCompute>(BECount))
9445 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9446 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9447 SymbolicMaxBECount =
9448 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9449 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9450 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9451}
9452
9453ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9454 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9455 bool AllowPredicates) {
9456 // If the condition was exit on true, convert the condition to exit on false
9457 CmpPredicate Pred;
9458 if (!ExitIfTrue)
9459 Pred = ExitCond->getCmpPredicate();
9460 else
9461 Pred = ExitCond->getInverseCmpPredicate();
9462 const ICmpInst::Predicate OriginalPred = Pred;
9463
9464 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9465 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9466
9467 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9468 AllowPredicates);
9469 if (EL.hasAnyInfo())
9470 return EL;
9471
9472 auto *ExhaustiveCount =
9473 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9474
9475 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9476 return ExhaustiveCount;
9477
9478 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9479 ExitCond->getOperand(1), L, OriginalPred);
9480}
/// Compute the exit limit for a loop that keeps iterating while
/// `LHS Pred RHS` holds (callers pass the predicate for staying in the loop).
///
/// The operands are evaluated at the scope of \p L, a loop-invariant LHS is
/// moved to the RHS, and the comparison is simplified. The predicate then
/// selects howFarToZero, howFarToNonZero, howManyLessThans or
/// howManyGreaterThans. When ControllingFiniteLoop holds and the RHS is
/// loop-invariant, this may also strengthen the no-wrap flags of the LHS
/// add-recurrence in place.
9481ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9482 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9483 bool ControlsOnlyExit, bool AllowPredicates) {
9484
9485 // Try to evaluate any dependencies out of the loop.
9486 LHS = getSCEVAtScope(LHS, L);
9487 RHS = getSCEVAtScope(RHS, L);
9488
9489 // At this point, we would like to compute how many iterations of the
9490 // loop the predicate will return true for these inputs.
9491 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9492 // If there is a loop-invariant, force it into the RHS.
9493 std::swap(LHS, RHS);
9495 }
9496
9497 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9499 // Simplify the operands before analyzing them.
9500 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9501
9502 // If we have a comparison of a chrec against a constant, try to use value
9503 // ranges to answer this query.
9504 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9505 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9506 if (AddRec->getLoop() == L) {
9507 // Form the constant range.
9508 ConstantRange CompRange =
9509 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9510
9511 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9512 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9513 }
9514
9515 // If this loop must exit based on this condition (or execute undefined
9516 // behaviour), see if we can improve wrap flags. This is essentially
9517 // a must execute style proof.
9518 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9519 // If we can prove the test sequence produced must repeat the same values
9520 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9521 // because if it did, we'd have an infinite (undefined) loop.
9522 // TODO: We can peel off any functions which are invertible *in L*. Loop
9523 // invariant terms are effectively constants for our purposes here.
9524 SCEVUse InnerLHS = LHS;
9525 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9526 InnerLHS = ZExt->getOperand();
9527 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9528 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9529 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9530 /*OrNegative=*/true)) {
9531 auto Flags = AR->getNoWrapFlags();
9532 Flags = setFlags(Flags, SCEV::FlagNW);
9533 SmallVector<SCEVUse> Operands{AR->operands()};
9534 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9535 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9536 }
9537
9538 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9539 // From no-self-wrap, this follows trivially from the fact that every
9540 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9541 // last value before (un)signed wrap. Since we know that last value
9542 // didn't exit, nor will any smaller one.
9543 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9544 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9545 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9546 AR && AR->getLoop() == L && AR->isAffine() &&
9547 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9548 isKnownPositive(AR->getStepRecurrence(*this))) {
9549 auto Flags = AR->getNoWrapFlags();
9550 Flags = setFlags(Flags, WrapType);
9551 SmallVector<SCEVUse> Operands{AR->operands()};
9552 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9553 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9554 }
9555 }
9556 }
9557
// Dispatch on the (possibly swapped and simplified) predicate. Cases that
// produce no information fall out of the switch and yield CouldNotCompute.
9558 switch (Pred) {
9559 case ICmpInst::ICMP_NE: { // while (X != Y)
9560 // Convert to: while (X-Y != 0)
9561 if (LHS->getType()->isPointerTy()) {
9564 return LHS;
9565 }
9566 if (RHS->getType()->isPointerTy()) {
9569 return RHS;
9570 }
9571 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9572 AllowPredicates);
9573 if (EL.hasAnyInfo())
9574 return EL;
9575 break;
9576 }
9577 case ICmpInst::ICMP_EQ: { // while (X == Y)
9578 // Convert to: while (X-Y == 0)
9579 if (LHS->getType()->isPointerTy()) {
9582 return LHS;
9583 }
9584 if (RHS->getType()->isPointerTy()) {
9587 return RHS;
9588 }
9589 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9590 if (EL.hasAnyInfo()) return EL;
9591 break;
9592 }
9593 case ICmpInst::ICMP_SLE:
9594 case ICmpInst::ICMP_ULE:
9595 // Since the loop is finite, an invariant RHS cannot include the boundary
9596 // value, otherwise it would loop forever.
9597 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9598 !isLoopInvariant(RHS, L)) {
9599 // Otherwise, perform the addition in a wider type, to avoid overflow.
9600 // If the LHS is an addrec with the appropriate nowrap flag, the
9601 // extension will be sunk into it and the exit count can be analyzed.
9602 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9603 if (!OldType)
9604 break;
9605 // Prefer doubling the bitwidth over adding a single bit to make it more
9606 // likely that we use a legal type.
9607 auto *NewType =
9608 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9609 if (ICmpInst::isSigned(Pred)) {
9610 LHS = getSignExtendExpr(LHS, NewType);
9611 RHS = getSignExtendExpr(RHS, NewType);
9612 } else {
9613 LHS = getZeroExtendExpr(LHS, NewType);
9614 RHS = getZeroExtendExpr(RHS, NewType);
9615 }
9616 }
9618 [[fallthrough]];
9619 case ICmpInst::ICMP_SLT:
9620 case ICmpInst::ICMP_ULT: { // while (X < Y)
9621 bool IsSigned = ICmpInst::isSigned(Pred);
9622 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9623 AllowPredicates);
9624 if (EL.hasAnyInfo())
9625 return EL;
9626 break;
9627 }
9628 case ICmpInst::ICMP_SGE:
9629 case ICmpInst::ICMP_UGE:
9630 // Since the loop is finite, an invariant RHS cannot include the boundary
9631 // value, otherwise it would loop forever.
9632 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9633 !isLoopInvariant(RHS, L))
9634 break;
9636 [[fallthrough]];
9637 case ICmpInst::ICMP_SGT:
9638 case ICmpInst::ICMP_UGT: { // while (X > Y)
9639 bool IsSigned = ICmpInst::isSigned(Pred);
9640 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9641 AllowPredicates);
9642 if (EL.hasAnyInfo())
9643 return EL;
9644 break;
9645 }
9646 default:
9647 break;
9648 }
9649
9650 return getCouldNotCompute();
9651}
9652
9653ScalarEvolution::ExitLimit
9654ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9655 SwitchInst *Switch,
9656 BasicBlock *ExitingBlock,
9657 bool ControlsOnlyExit) {
9658 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9659
9660 // Give up if the exit is the default dest of a switch.
9661 if (Switch->getDefaultDest() == ExitingBlock)
9662 return getCouldNotCompute();
9663
9664 assert(L->contains(Switch->getDefaultDest()) &&
9665 "Default case must not exit the loop!");
9666 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9667 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9668
9669 // while (X != Y) --> while (X-Y != 0)
9670 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9671 if (EL.hasAnyInfo())
9672 return EL;
9673
9674 return getCouldNotCompute();
9675}
9676
/// Evaluate the add-recurrence \p AddRec at the constant iteration \p C and
/// return the resulting constant. The evaluation is expected to fold to a
/// SCEVConstant; the result is cast<SCEVConstant> unconditionally.
9677static ConstantInt *
9679 ScalarEvolution &SE) {
9680 const SCEV *InVal = SE.getConstant(C);
9681 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9683 "Evaluation of SCEV at constant didn't fold correctly?");
9684 return cast<SCEVConstant>(Val)->getValue();
9685}
9686
9687ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9688 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9689 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9690 if (!RHS)
9691 return getCouldNotCompute();
9692
9693 const BasicBlock *Latch = L->getLoopLatch();
9694 if (!Latch)
9695 return getCouldNotCompute();
9696
9697 const BasicBlock *Predecessor = L->getLoopPredecessor();
9698 if (!Predecessor)
9699 return getCouldNotCompute();
9700
9701 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9702 // Return LHS in OutLHS, shift_op in OutOpCode, and the shift amount in
9703 // OutShiftAmt.
9704 auto MatchPositiveShift = [](Value *V, Value *&OutLHS,
9705 Instruction::BinaryOps &OutOpCode,
9706 unsigned &OutShiftAmt) {
9707 using namespace PatternMatch;
9708
9709 ConstantInt *ShiftAmt;
9710 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9711 OutOpCode = Instruction::LShr;
9712 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9713 OutOpCode = Instruction::AShr;
9714 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9715 OutOpCode = Instruction::Shl;
9716 else
9717 return false;
9718
9719 uint64_t Amt = ShiftAmt->getValue().getLimitedValue();
9720 if (Amt == 0 || Amt >= OutLHS->getType()->getScalarSizeInBits())
9721 return false;
9722 OutShiftAmt = Amt;
9723 return true;
9724 };
9725
9726 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9727 //
9728 // loop:
9729 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9730 // %iv.shifted = lshr i32 %iv, <positive constant>
9731 //
9732 // Return true on a successful match. Return the corresponding PHI node (%iv
9733 // above) in PNOut, the opcode of the shift operation in OpCodeOut, and the
9734 // shift amount in ShiftAmtOut.
9735 auto MatchShiftRecurrence = [&](Value *V, PHINode *&PNOut,
9736 Instruction::BinaryOps &OpCodeOut,
9737 unsigned &ShiftAmtOut) {
9738 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9739
9740 {
9742 Value *V;
9743 unsigned Amt;
9744
9745 // If we encounter a shift instruction, "peel off" the shift operation,
9746 // and remember that we did so. Later when we inspect %iv's backedge
9747 // value, we will make sure that the backedge value uses the same
9748 // operation.
9749 //
9750 // Note: the peeled shift operation does not have to be the same
9751 // instruction as the one feeding into the PHI's backedge value. We only
9752 // really care about it being the same *kind* of shift instruction --
9753 // that's all that is required for our later inferences to hold.
9754 if (MatchPositiveShift(LHS, V, OpC, Amt)) {
9755 PostShiftOpCode = OpC;
9756 LHS = V;
9757 }
9758 }
9759
9760 PNOut = dyn_cast<PHINode>(LHS);
9761 if (!PNOut || PNOut->getParent() != L->getHeader())
9762 return false;
9763
9764 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9765 Value *OpLHS;
9766
9767 return
9768 // The backedge value for the PHI node must be a shift by a positive
9769 // amount
9770 MatchPositiveShift(BEValue, OpLHS, OpCodeOut, ShiftAmtOut) &&
9771
9772 // of the PHI node itself
9773 OpLHS == PNOut &&
9774
9775 // and the kind of shift should be match the kind of shift we peeled
9776 // off, if any.
9777 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9778 };
9779
9780 PHINode *PN;
9782 unsigned ShiftAmt;
9783 if (!MatchShiftRecurrence(LHS, PN, OpCode, ShiftAmt))
9784 return getCouldNotCompute();
9785
9786 const DataLayout &DL = getDataLayout();
9787
9788 // The key rationale for this optimization is that for some kinds of shift
9789 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9790 // within a finite number of iterations. If the condition guarding the
9791 // backedge (in the sense that the backedge is taken if the condition is true)
9792 // is false for the value the shift recurrence stabilizes to, then we know
9793 // that the backedge is taken only a finite number of times.
9794
9795 ConstantInt *StableValue = nullptr;
9796 switch (OpCode) {
9797 default:
9798 llvm_unreachable("Impossible case!");
9799
9800 case Instruction::AShr: {
9801 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9802 // bitwidth(K) iterations.
9803 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9804 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9805 Predecessor->getTerminator(), &DT);
9806 auto *Ty = cast<IntegerType>(RHS->getType());
9807 if (Known.isNonNegative())
9808 StableValue = ConstantInt::get(Ty, 0);
9809 else if (Known.isNegative())
9810 StableValue = ConstantInt::get(Ty, -1, true);
9811 else
9812 return getCouldNotCompute();
9813
9814 break;
9815 }
9816 case Instruction::LShr:
9817 case Instruction::Shl:
9818 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9819 // stabilize to 0 in at most bitwidth(K) iterations.
9820 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9821 break;
9822 }
9823
9824 auto *Result =
9825 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9826 assert(Result->getType()->isIntegerTy(1) &&
9827 "Otherwise cannot be an operand to a branch instruction");
9828
9829 if (Result->isNullValue()) {
9830 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9831 unsigned MaxBTC = BitWidth;
9832
9833 // For right-shift recurrences (lshr/ashr with non-negative start), we can
9834 // compute a tighter max backedge-taken count from the range of the start
9835 // value. After k shifts of ShiftAmt, value = start >> (k * ShiftAmt).
9836 // The value reaches 0 (the stable value) when k * ShiftAmt >=
9837 // activeBits(start), so max BTC = ceil(activeBits(maxStart) / ShiftAmt).
9838 if (OpCode == Instruction::LShr || OpCode == Instruction::AShr) {
9839 Value *StartValue = PN->getIncomingValueForBlock(Predecessor);
9840 const SCEV *StartSCEV = getSCEV(StartValue);
9841 APInt MaxStart = getUnsignedRangeMax(StartSCEV);
9842 if (MaxStart.isStrictlyPositive()) {
9843 unsigned ActiveBits = MaxStart.getActiveBits();
9844 unsigned RangeBTC = divideCeil(ActiveBits, ShiftAmt);
9845 MaxBTC = std::min(MaxBTC, RangeBTC);
9846 }
9847 }
9848
9849 const SCEV *UpperBound =
9851 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9852 }
9853
9854 return getCouldNotCompute();
9855}
9856
9857/// Return true if we can constant fold an instruction of the specified type,
9858/// assuming that all operands were constants.
9859static bool CanConstantFold(const Instruction *I) {
9863 return true;
9864
9865 if (const CallInst *CI = dyn_cast<CallInst>(I))
9866 if (const Function *F = CI->getCalledFunction())
9867 return canConstantFoldCallTo(CI, F);
9868 return false;
9869}
9870
9871/// Determine whether this instruction can constant evolve within this loop
9872/// assuming its operands can all constant evolve.
9873static bool canConstantEvolve(Instruction *I, const Loop *L) {
9874 // An instruction outside of the loop can't be derived from a loop PHI.
9875 if (!L->contains(I)) return false;
9876
9877 if (isa<PHINode>(I)) {
9878 // We don't currently keep track of the control flow needed to evaluate
9879 // PHIs, so we cannot handle PHIs inside of loops.
9880 return L->getHeader() == I->getParent();
9881 }
9882
9883 // If we won't be able to constant fold this expression even if the operands
9884 // are constants, bail early.
9885 return CanConstantFold(I);
9886}
9887
9888/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9889/// recursing through each instruction operand until reaching a loop header phi.
9890static PHINode *
9893 unsigned Depth) {
9895 return nullptr;
9896
9897 // Otherwise, we can evaluate this instruction if all of its operands are
9898 // constant or derived from a PHI node themselves.
9899 PHINode *PHI = nullptr;
9900 for (Value *Op : UseInst->operands()) {
9901 if (isa<Constant>(Op)) continue;
9902
9904 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9905
9906 PHINode *P = dyn_cast<PHINode>(OpInst);
9907 if (!P)
9908 // If this operand is already visited, reuse the prior result.
9909 // We may have P != PHI if this is the deepest point at which the
9910 // inconsistent paths meet.
9911 P = PHIMap.lookup(OpInst);
9912 if (!P) {
9913 // Recurse and memoize the results, whether a phi is found or not.
9914 // This recursive call invalidates pointers into PHIMap.
9915 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9916 PHIMap[OpInst] = P;
9917 }
9918 if (!P)
9919 return nullptr; // Not evolving from PHI
9920 if (PHI && PHI != P)
9921 return nullptr; // Evolving from multiple different PHIs.
9922 PHI = P;
9923 }
9924 // This is a expression evolving from a constant PHI!
9925 return PHI;
9926}
9927
9928/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9929/// in the loop that V is derived from. We allow arbitrary operations along the
9930/// way, but the operands of an operation must either be constants or a value
9931/// derived from a constant PHI. If this expression does not fit with these
9932/// constraints, return null.
9935 if (!I || !canConstantEvolve(I, L)) return nullptr;
9936
9937 if (PHINode *PN = dyn_cast<PHINode>(I))
9938 return PN;
9939
9940 // Record non-constant instructions contained by the loop.
9942 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9943}
9944
9945/// EvaluateExpression - Given an expression that passes the
9946/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9947/// in the loop has the value PHIVal. If we can't fold this expression for some
9948/// reason, return null.
9951 const DataLayout &DL,
9952 const TargetLibraryInfo *TLI) {
9953 // Convenient constant check, but redundant for recursive calls.
9954 if (Constant *C = dyn_cast<Constant>(V)) return C;
9956 if (!I) return nullptr;
9957
9958 if (Constant *C = Vals.lookup(I)) return C;
9959
9960 // An instruction inside the loop depends on a value outside the loop that we
9961 // weren't given a mapping for, or a value such as a call inside the loop.
9962 if (!canConstantEvolve(I, L)) return nullptr;
9963
9964 // An unmapped PHI can be due to a branch or another loop inside this loop,
9965 // or due to this not being the initial iteration through a loop where we
9966 // couldn't compute the evolution of this particular PHI last time.
9967 if (isa<PHINode>(I)) return nullptr;
9968
9969 std::vector<Constant*> Operands(I->getNumOperands());
9970
9971 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9972 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9973 if (!Operand) {
9974 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9975 if (!Operands[i]) return nullptr;
9976 continue;
9977 }
9978 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9979 Vals[Operand] = C;
9980 if (!C) return nullptr;
9981 Operands[i] = C;
9982 }
9983
9984 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9985 /*AllowNonDeterministic=*/false);
9986}
9987
9988
9989// If every incoming value to PN except the one for BB is a specific Constant,
9990// return that, else return nullptr.
9992 Constant *IncomingVal = nullptr;
9993
9994 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9995 if (PN->getIncomingBlock(i) == BB)
9996 continue;
9997
9998 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9999 if (!CurrentVal)
10000 return nullptr;
10001
10002 if (IncomingVal != CurrentVal) {
10003 if (IncomingVal)
10004 return nullptr;
10005 IncomingVal = CurrentVal;
10006 }
10007 }
10008
10009 return IncomingVal;
10010}
10011
10012/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
10013/// in the header of its containing loop, we know the loop executes a
10014/// constant number of times, and the PHI node is just a recurrence
10015/// involving constants, fold it.
10016Constant *
10017ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
10018 const APInt &BEs,
10019 const Loop *L) {
10020 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
10021 if (!Inserted)
10022 return I->second;
10023
10025 return nullptr; // Not going to evaluate it.
10026
10027 Constant *&RetVal = I->second;
10028
10029 DenseMap<Instruction *, Constant *> CurrentIterVals;
10030 BasicBlock *Header = L->getHeader();
10031 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10032
10033 BasicBlock *Latch = L->getLoopLatch();
10034 if (!Latch)
10035 return nullptr;
10036
10037 for (PHINode &PHI : Header->phis()) {
10038 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10039 CurrentIterVals[&PHI] = StartCST;
10040 }
10041 if (!CurrentIterVals.count(PN))
10042 return RetVal = nullptr;
10043
10044 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10045
10046 // Execute the loop symbolically to determine the exit value.
10047 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10048 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10049
10050 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10051 unsigned IterationNum = 0;
10052 const DataLayout &DL = getDataLayout();
10053 for (; ; ++IterationNum) {
10054 if (IterationNum == NumIterations)
10055 return RetVal = CurrentIterVals[PN]; // Got exit value!
10056
10057 // Compute the value of the PHIs for the next iteration.
10058 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10059 DenseMap<Instruction *, Constant *> NextIterVals;
10060 Constant *NextPHI =
10061 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10062 if (!NextPHI)
10063 return nullptr; // Couldn't evaluate!
10064 NextIterVals[PN] = NextPHI;
10065
10066 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10067
10068 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10069 // cease to be able to evaluate one of them or if they stop evolving,
10070 // because that doesn't necessarily prevent us from computing PN.
10072 for (const auto &I : CurrentIterVals) {
10073 PHINode *PHI = dyn_cast<PHINode>(I.first);
10074 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10075 PHIsToCompute.emplace_back(PHI, I.second);
10076 }
10077 // We use two distinct loops because EvaluateExpression may invalidate any
10078 // iterators into CurrentIterVals.
10079 for (const auto &I : PHIsToCompute) {
10080 PHINode *PHI = I.first;
10081 Constant *&NextPHI = NextIterVals[PHI];
10082 if (!NextPHI) { // Not already computed.
10083 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10084 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10085 }
10086 if (NextPHI != I.second)
10087 StoppedEvolving = false;
10088 }
10089
10090 // If all entries in CurrentIterVals == NextIterVals then we can stop
10091 // iterating, the loop can't continue to change.
10092 if (StoppedEvolving)
10093 return RetVal = CurrentIterVals[PN];
10094
10095 CurrentIterVals.swap(NextIterVals);
10096 }
10097}
10098
10099const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10100 Value *Cond,
10101 bool ExitWhen) {
10102 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10103 if (!PN) return getCouldNotCompute();
10104
10105 // If the loop is canonicalized, the PHI will have exactly two entries.
10106 // That's the only form we support here.
10107 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10108
10109 DenseMap<Instruction *, Constant *> CurrentIterVals;
10110 BasicBlock *Header = L->getHeader();
10111 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10112
10113 BasicBlock *Latch = L->getLoopLatch();
10114 assert(Latch && "Should follow from NumIncomingValues == 2!");
10115
10116 for (PHINode &PHI : Header->phis()) {
10117 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10118 CurrentIterVals[&PHI] = StartCST;
10119 }
10120 if (!CurrentIterVals.count(PN))
10121 return getCouldNotCompute();
10122
10123 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10124 // the loop symbolically to determine when the condition gets a value of
10125 // "ExitWhen".
10126 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10127 const DataLayout &DL = getDataLayout();
10128 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10129 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10130 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10131
10132 // Couldn't symbolically evaluate.
10133 if (!CondVal) return getCouldNotCompute();
10134
10135 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10136 ++NumBruteForceTripCountsComputed;
10137 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10138 }
10139
10140 // Update all the PHI nodes for the next iteration.
10141 DenseMap<Instruction *, Constant *> NextIterVals;
10142
10143 // Create a list of which PHIs we need to compute. We want to do this before
10144 // calling EvaluateExpression on them because that may invalidate iterators
10145 // into CurrentIterVals.
10146 SmallVector<PHINode *, 8> PHIsToCompute;
10147 for (const auto &I : CurrentIterVals) {
10148 PHINode *PHI = dyn_cast<PHINode>(I.first);
10149 if (!PHI || PHI->getParent() != Header) continue;
10150 PHIsToCompute.push_back(PHI);
10151 }
10152 for (PHINode *PHI : PHIsToCompute) {
10153 Constant *&NextPHI = NextIterVals[PHI];
10154 if (NextPHI) continue; // Already computed!
10155
10156 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10157 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10158 }
10159 CurrentIterVals.swap(NextIterVals);
10160 }
10161
10162 // Too many iterations were needed to evaluate.
10163 return getCouldNotCompute();
10164}
10165
10166const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10168 ValuesAtScopes[V];
10169 // Check to see if we've folded this expression at this loop before.
10170 for (auto &LS : Values)
10171 if (LS.first == L)
10172 return LS.second ? LS.second : V;
10173
10174 Values.emplace_back(L, nullptr);
10175
10176 // Otherwise compute it.
10177 const SCEV *C = computeSCEVAtScope(V, L);
10178 for (auto &LS : reverse(ValuesAtScopes[V]))
10179 if (LS.first == L) {
10180 LS.second = C;
10181 if (!isa<SCEVConstant>(C))
10182 ValuesAtScopesUsers[C].push_back({L, V});
10183 break;
10184 }
10185 return C;
10186}
10187
10188/// This builds up a Constant using the ConstantExpr interface. That way, we
10189/// will return Constants for objects which aren't represented by a
10190/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10191/// Returns NULL if the SCEV isn't representable as a Constant.
10193 switch (V->getSCEVType()) {
10194 case scCouldNotCompute:
10195 case scAddRecExpr:
10196 case scVScale:
10197 return nullptr;
10198 case scConstant:
10199 return cast<SCEVConstant>(V)->getValue();
10200 case scUnknown:
10202 case scPtrToAddr: {
10204 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10205 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10206
10207 return nullptr;
10208 }
10209 case scPtrToInt: {
10211 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10212 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10213
10214 return nullptr;
10215 }
10216 case scTruncate: {
10218 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10219 return ConstantExpr::getTrunc(CastOp, ST->getType());
10220 return nullptr;
10221 }
10222 case scAddExpr: {
10223 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10224 Constant *C = nullptr;
10225 for (const SCEV *Op : SA->operands()) {
10227 if (!OpC)
10228 return nullptr;
10229 if (!C) {
10230 C = OpC;
10231 continue;
10232 }
10233 assert(!C->getType()->isPointerTy() &&
10234 "Can only have one pointer, and it must be last");
10235 if (OpC->getType()->isPointerTy()) {
10236 // The offsets have been converted to bytes. We can add bytes using
10237 // an i8 GEP.
10238 C = ConstantExpr::getPtrAdd(OpC, C);
10239 } else {
10240 C = ConstantExpr::getAdd(C, OpC);
10241 }
10242 }
10243 return C;
10244 }
10245 case scMulExpr:
10246 case scSignExtend:
10247 case scZeroExtend:
10248 case scUDivExpr:
10249 case scSMaxExpr:
10250 case scUMaxExpr:
10251 case scSMinExpr:
10252 case scUMinExpr:
10254 return nullptr;
10255 }
10256 llvm_unreachable("Unknown SCEV kind!");
10257}
10258
10259const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10260 SmallVectorImpl<SCEVUse> &NewOps) {
10261 switch (S->getSCEVType()) {
10262 case scTruncate:
10263 case scZeroExtend:
10264 case scSignExtend:
10265 case scPtrToAddr:
10266 case scPtrToInt:
10267 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10268 case scAddRecExpr: {
10269 auto *AddRec = cast<SCEVAddRecExpr>(S);
10270 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10271 }
10272 case scAddExpr:
10273 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10274 case scMulExpr:
10275 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10276 case scUDivExpr:
10277 return getUDivExpr(NewOps[0], NewOps[1]);
10278 case scUMaxExpr:
10279 case scSMaxExpr:
10280 case scUMinExpr:
10281 case scSMinExpr:
10282 return getMinMaxExpr(S->getSCEVType(), NewOps);
10284 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10285 case scConstant:
10286 case scVScale:
10287 case scUnknown:
10288 return S;
10289 case scCouldNotCompute:
10290 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10291 }
10292 llvm_unreachable("Unknown SCEV kind!");
10293}
10294
10295const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10296 switch (V->getSCEVType()) {
10297 case scConstant:
10298 case scVScale:
10299 return V;
10300 case scAddRecExpr: {
10301 // If this is a loop recurrence for a loop that does not contain L, then we
10302 // are dealing with the final value computed by the loop.
10303 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10304 // First, attempt to evaluate each operand.
10305 // Avoid performing the look-up in the common case where the specified
10306 // expression has no loop-variant portions.
10307 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10308 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10309 if (OpAtScope == AddRec->getOperand(i))
10310 continue;
10311
10312 // Okay, at least one of these operands is loop variant but might be
10313 // foldable. Build a new instance of the folded commutative expression.
10315 NewOps.reserve(AddRec->getNumOperands());
10316 append_range(NewOps, AddRec->operands().take_front(i));
10317 NewOps.push_back(OpAtScope);
10318 for (++i; i != e; ++i)
10319 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10320
10321 const SCEV *FoldedRec = getAddRecExpr(
10322 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10323 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10324 // The addrec may be folded to a nonrecurrence, for example, if the
10325 // induction variable is multiplied by zero after constant folding. Go
10326 // ahead and return the folded value.
10327 if (!AddRec)
10328 return FoldedRec;
10329 break;
10330 }
10331
10332 // If the scope is outside the addrec's loop, evaluate it by using the
10333 // loop exit value of the addrec.
10334 if (!AddRec->getLoop()->contains(L)) {
10335 // To evaluate this recurrence, we need to know how many times the AddRec
10336 // loop iterates. Compute this now.
10337 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10338 if (BackedgeTakenCount == getCouldNotCompute())
10339 return AddRec;
10340
10341 // Then, evaluate the AddRec.
10342 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10343 }
10344
10345 return AddRec;
10346 }
10347 case scTruncate:
10348 case scZeroExtend:
10349 case scSignExtend:
10350 case scPtrToAddr:
10351 case scPtrToInt:
10352 case scAddExpr:
10353 case scMulExpr:
10354 case scUDivExpr:
10355 case scUMaxExpr:
10356 case scSMaxExpr:
10357 case scUMinExpr:
10358 case scSMinExpr:
10359 case scSequentialUMinExpr: {
10360 ArrayRef<SCEVUse> Ops = V->operands();
10361 // Avoid performing the look-up in the common case where the specified
10362 // expression has no loop-variant portions.
10363 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10364 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10365 if (OpAtScope != Ops[i].getPointer()) {
10366 // Okay, at least one of these operands is loop variant but might be
10367 // foldable. Build a new instance of the folded commutative expression.
10369 NewOps.reserve(Ops.size());
10370 append_range(NewOps, Ops.take_front(i));
10371 NewOps.push_back(OpAtScope);
10372
10373 for (++i; i != e; ++i) {
10374 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10375 NewOps.push_back(OpAtScope);
10376 }
10377
10378 return getWithOperands(V, NewOps);
10379 }
10380 }
10381 // If we got here, all operands are loop invariant.
10382 return V;
10383 }
10384 case scUnknown: {
10385 // If this instruction is evolved from a constant-evolving PHI, compute the
10386 // exit value from the loop without using SCEVs.
10387 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10389 if (!I)
10390 return V; // This is some other type of SCEVUnknown, just return it.
10391
10392 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10393 const Loop *CurrLoop = this->LI[I->getParent()];
10394 // Looking for loop exit value.
10395 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10396 PN->getParent() == CurrLoop->getHeader()) {
10397 // Okay, there is no closed form solution for the PHI node. Check
10398 // to see if the loop that contains it has a known backedge-taken
10399 // count. If so, we may be able to force computation of the exit
10400 // value.
10401 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10402 // This trivial case can show up in some degenerate cases where
10403 // the incoming IR has not yet been fully simplified.
10404 if (BackedgeTakenCount->isZero()) {
10405 Value *InitValue = nullptr;
10406 bool MultipleInitValues = false;
10407 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10408 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10409 if (!InitValue)
10410 InitValue = PN->getIncomingValue(i);
10411 else if (InitValue != PN->getIncomingValue(i)) {
10412 MultipleInitValues = true;
10413 break;
10414 }
10415 }
10416 }
10417 if (!MultipleInitValues && InitValue)
10418 return getSCEV(InitValue);
10419 }
10420 // Do we have a loop invariant value flowing around the backedge
10421 // for a loop which must execute the backedge?
10422 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10423 isKnownNonZero(BackedgeTakenCount) &&
10424 PN->getNumIncomingValues() == 2) {
10425
10426 unsigned InLoopPred =
10427 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10428 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10429 if (CurrLoop->isLoopInvariant(BackedgeVal))
10430 return getSCEV(BackedgeVal);
10431 }
10432 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10433 // Okay, we know how many times the containing loop executes. If
10434 // this is a constant evolving PHI node, get the final value at
10435 // the specified iteration number.
10436 Constant *RV =
10437 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10438 if (RV)
10439 return getSCEV(RV);
10440 }
10441 }
10442 }
10443
10444 // Okay, this is an expression that we cannot symbolically evaluate
10445 // into a SCEV. Check to see if it's possible to symbolically evaluate
10446 // the arguments into constants, and if so, try to constant propagate the
10447 // result. This is particularly useful for computing loop exit values.
10448 if (!CanConstantFold(I))
10449 return V; // This is some other type of SCEVUnknown, just return it.
10450
10451 SmallVector<Constant *, 4> Operands;
10452 Operands.reserve(I->getNumOperands());
10453 bool MadeImprovement = false;
10454 for (Value *Op : I->operands()) {
10455 if (Constant *C = dyn_cast<Constant>(Op)) {
10456 Operands.push_back(C);
10457 continue;
10458 }
10459
10460 // If any of the operands is non-constant and if they are
10461 // non-integer and non-pointer, don't even try to analyze them
10462 // with scev techniques.
10463 if (!isSCEVable(Op->getType()))
10464 return V;
10465
10466 const SCEV *OrigV = getSCEV(Op);
10467 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10468 MadeImprovement |= OrigV != OpV;
10469
10471 if (!C)
10472 return V;
10473 assert(C->getType() == Op->getType() && "Type mismatch");
10474 Operands.push_back(C);
10475 }
10476
10477 // Check to see if getSCEVAtScope actually made an improvement.
10478 if (!MadeImprovement)
10479 return V; // This is some other type of SCEVUnknown, just return it.
10480
10481 Constant *C = nullptr;
10482 const DataLayout &DL = getDataLayout();
10483 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10484 /*AllowNonDeterministic=*/false);
10485 if (!C)
10486 return V;
10487 return getSCEV(C);
10488 }
10489 case scCouldNotCompute:
10490 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10491 }
10492 llvm_unreachable("Unknown SCEV type!");
10493}
10494
10496 return getSCEVAtScope(getSCEV(V), L);
10497}
10498
10499const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10501 return stripInjectiveFunctions(ZExt->getOperand());
10503 return stripInjectiveFunctions(SExt->getOperand());
10504 return S;
10505}
10506
10507/// Finds the minimum unsigned root of the following equation:
10508///
10509/// A * X = B (mod N)
10510///
10511/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10512/// A and B isn't important.
10513///
10514/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10515static const SCEV *
10518 ScalarEvolution &SE, const Loop *L) {
10519 uint32_t BW = A.getBitWidth();
10520 assert(BW == SE.getTypeSizeInBits(B->getType()));
10521 assert(A != 0 && "A must be non-zero.");
10522
10523 // 1. D = gcd(A, N)
10524 //
10525 // The gcd of A and N may have only one prime factor: 2. The number of
10526 // trailing zeros in A is its multiplicity
10527 uint32_t Mult2 = A.countr_zero();
10528 // D = 2^Mult2
10529
10530 // 2. Check if B is divisible by D.
10531 //
10532 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10533 // is not less than multiplicity of this prime factor for D.
10534 unsigned MinTZ = SE.getMinTrailingZeros(B);
10535 // Try again with the terminator of the loop predecessor for context-specific
10536 // result, if MinTZ s too small.
10537 if (MinTZ < Mult2 && L->getLoopPredecessor())
10538 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10539 if (MinTZ < Mult2) {
10540 // Check if we can prove there's no remainder using URem.
10541 const SCEV *URem =
10542 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10543 const SCEV *Zero = SE.getZero(B->getType());
10544 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10545 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10546 if (!Predicates)
10547 return SE.getCouldNotCompute();
10548
10549 // Avoid adding a predicate that is known to be false.
10550 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10551 return SE.getCouldNotCompute();
10552 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10553 }
10554 }
10555
10556 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10557 // modulo (N / D).
10558 //
10559 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10560 // (N / D) in general. The inverse itself always fits into BW bits, though,
10561 // so we immediately truncate it.
10562 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10563 APInt I = AD.multiplicativeInverse().zext(BW);
10564
10565 // 4. Compute the minimum unsigned root of the equation:
10566 // I * (B / D) mod (N / D)
10567 // To simplify the computation, we factor out the divide by D:
10568 // (I * B mod N) / D
10569 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10570 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10571}
10572
10573/// For a given quadratic addrec, generate coefficients of the corresponding
10574/// quadratic equation, multiplied by a common value to ensure that they are
10575/// integers.
10576/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10577/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10578/// were multiplied by, and BitWidth is the bit width of the original addrec
10579/// coefficients.
10580/// This function returns std::nullopt if the addrec coefficients are not
10581/// compile-time constants.
10582static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10584 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10585 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10586 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10587 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10588 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10589 << *AddRec << '\n');
10590
10591 // We currently can only solve this if the coefficients are constants.
10592 if (!LC || !MC || !NC) {
10593 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10594 return std::nullopt;
10595 }
10596
10597 APInt L = LC->getAPInt();
10598 APInt M = MC->getAPInt();
10599 APInt N = NC->getAPInt();
10600 assert(!N.isZero() && "This is not a quadratic addrec");
10601
10602 unsigned BitWidth = LC->getAPInt().getBitWidth();
10603 unsigned NewWidth = BitWidth + 1;
10604 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10605 << BitWidth << '\n');
10606 // The sign-extension (as opposed to a zero-extension) here matches the
10607 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10608 N = N.sext(NewWidth);
10609 M = M.sext(NewWidth);
10610 L = L.sext(NewWidth);
10611
10612 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10613 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10614 // L+M, L+2M+N, L+3M+3N, ...
10615 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10616 //
10617 // The equation Acc = 0 is then
10618 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10619 // In a quadratic form it becomes:
10620 // N n^2 + (2M-N) n + 2L = 0.
10621
10622 APInt A = N;
10623 APInt B = 2 * M - A;
10624 APInt C = 2 * L;
10625 APInt T = APInt(NewWidth, 2);
10626 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10627 << "x + " << C << ", coeff bw: " << NewWidth
10628 << ", multiplied by " << T << '\n');
10629 return std::make_tuple(A, B, C, T, BitWidth);
10630}
10631
10632/// Helper function to compare optional APInts:
10633/// (a) if X and Y both exist, return min(X, Y),
10634/// (b) if neither X nor Y exist, return std::nullopt,
10635/// (c) if exactly one of X and Y exists, return that value.
10636static std::optional<APInt> MinOptional(std::optional<APInt> X,
10637 std::optional<APInt> Y) {
10638 if (X && Y) {
10639 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10640 APInt XW = X->sext(W);
10641 APInt YW = Y->sext(W);
10642 return XW.slt(YW) ? *X : *Y;
10643 }
10644 if (!X && !Y)
10645 return std::nullopt;
10646 return X ? *X : *Y;
10647}
10648
10649/// Helper function to truncate an optional APInt to a given BitWidth.
10650/// When solving addrec-related equations, it is preferable to return a value
10651/// that has the same bit width as the original addrec's coefficients. If the
10652/// solution fits in the original bit width, truncate it (except for i1).
10653/// Returning a value of a different bit width may inhibit some optimizations.
10654///
10655/// In general, a solution to a quadratic equation generated from an addrec
10656/// may require BW+1 bits, where BW is the bit width of the addrec's
10657/// coefficients. The reason is that the coefficients of the quadratic
10658/// equation are BW+1 bits wide (to avoid truncation when converting from
10659/// the addrec to the equation).
10660static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10661 unsigned BitWidth) {
10662 if (!X)
10663 return std::nullopt;
10664 unsigned W = X->getBitWidth();
10666 return X->trunc(BitWidth);
10667 return X;
10668}
10669
10670/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10671/// iterations. The values L, M, N are assumed to be signed, and they
10672/// should all have the same bit widths.
10673/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10674/// where BW is the bit width of the addrec's coefficients.
10675/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10676/// returned as such, otherwise the bit width of the returned value may
10677/// be greater than BW.
10678///
10679/// This function returns std::nullopt if
10680/// (a) the addrec coefficients are not constant, or
10681/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10682/// like x^2 = 5, no integer solutions exist, in other cases an integer
10683/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10684static std::optional<APInt>
10686 APInt A, B, C, M;
10687 unsigned BitWidth;
10688 auto T = GetQuadraticEquation(AddRec);
10689 if (!T)
10690 return std::nullopt;
10691
10692 std::tie(A, B, C, M, BitWidth) = *T;
10693 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10694 std::optional<APInt> X =
10696 if (!X)
10697 return std::nullopt;
10698
10699 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10700 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10701 if (!V->isZero())
10702 return std::nullopt;
10703
10704 return TruncIfPossible(X, BitWidth);
10705}
10706
10707/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10708/// iterations. The values M, N are assumed to be signed, and they
10709/// should all have the same bit widths.
10710/// Find the least n such that c(n) does not belong to the given range,
10711/// while c(n-1) does.
10712///
10713/// This function returns std::nullopt if
10714/// (a) the addrec coefficients are not constant, or
10715/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10716/// bounds of the range.
10717static std::optional<APInt>
10719 const ConstantRange &Range, ScalarEvolution &SE) {
10720 assert(AddRec->getOperand(0)->isZero() &&
10721 "Starting value of addrec should be 0");
10722 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10723 << Range << ", addrec " << *AddRec << '\n');
10724 // This case is handled in getNumIterationsInRange. Here we can assume that
10725 // we start in the range.
10726 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10727 "Addrec's initial value should be in range");
10728
10729 APInt A, B, C, M;
10730 unsigned BitWidth;
10731 auto T = GetQuadraticEquation(AddRec);
10732 if (!T)
10733 return std::nullopt;
10734
10735 // Be careful about the return value: there can be two reasons for not
10736 // returning an actual number. First, if no solutions to the equations
10737 // were found, and second, if the solutions don't leave the given range.
10738 // The first case means that the actual solution is "unknown", the second
10739 // means that it's known, but not valid. If the solution is unknown, we
10740 // cannot make any conclusions.
10741 // Return a pair: the optional solution and a flag indicating if the
10742 // solution was found.
10743 auto SolveForBoundary =
10744 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10745 // Solve for signed overflow and unsigned overflow, pick the lower
10746 // solution.
10747 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10748 << Bound << " (before multiplying by " << M << ")\n");
10749 Bound *= M; // The quadratic equation multiplier.
10750
10751 std::optional<APInt> SO;
10752 if (BitWidth > 1) {
10753 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10754 "signed overflow\n");
10756 }
10757 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10758 "unsigned overflow\n");
10759 std::optional<APInt> UO =
10761
10762 auto LeavesRange = [&] (const APInt &X) {
10763 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10764 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10765 if (Range.contains(V0->getValue()))
10766 return false;
10767 // X should be at least 1, so X-1 is non-negative.
10768 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10770 if (Range.contains(V1->getValue()))
10771 return true;
10772 return false;
10773 };
10774
10775 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10776 // can be a solution, but the function failed to find it. We cannot treat it
10777 // as "no solution".
10778 if (!SO || !UO)
10779 return {std::nullopt, false};
10780
10781 // Check the smaller value first to see if it leaves the range.
10782 // At this point, both SO and UO must have values.
10783 std::optional<APInt> Min = MinOptional(SO, UO);
10784 if (LeavesRange(*Min))
10785 return { Min, true };
10786 std::optional<APInt> Max = Min == SO ? UO : SO;
10787 if (LeavesRange(*Max))
10788 return { Max, true };
10789
10790 // Solutions were found, but were eliminated, hence the "true".
10791 return {std::nullopt, true};
10792 };
10793
10794 std::tie(A, B, C, M, BitWidth) = *T;
10795 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10796 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10797 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10798 auto SL = SolveForBoundary(Lower);
10799 auto SU = SolveForBoundary(Upper);
10800 // If any of the solutions was unknown, no meaninigful conclusions can
10801 // be made.
10802 if (!SL.second || !SU.second)
10803 return std::nullopt;
10804
10805 // Claim: The correct solution is not some value between Min and Max.
10806 //
10807 // Justification: Assuming that Min and Max are different values, one of
10808 // them is when the first signed overflow happens, the other is when the
10809 // first unsigned overflow happens. Crossing the range boundary is only
10810 // possible via an overflow (treating 0 as a special case of it, modeling
10811 // an overflow as crossing k*2^W for some k).
10812 //
10813 // The interesting case here is when Min was eliminated as an invalid
10814 // solution, but Max was not. The argument is that if there was another
10815 // overflow between Min and Max, it would also have been eliminated if
10816 // it was considered.
10817 //
10818 // For a given boundary, it is possible to have two overflows of the same
10819 // type (signed/unsigned) without having the other type in between: this
10820 // can happen when the vertex of the parabola is between the iterations
10821 // corresponding to the overflows. This is only possible when the two
10822 // overflows cross k*2^W for the same k. In such case, if the second one
10823 // left the range (and was the first one to do so), the first overflow
10824 // would have to enter the range, which would mean that either we had left
10825 // the range before or that we started outside of it. Both of these cases
10826 // are contradictions.
10827 //
10828 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10829 // solution is not some value between the Max for this boundary and the
10830 // Min of the other boundary.
10831 //
10832 // Justification: Assume that we had such Max_A and Min_B corresponding
10833 // to range boundaries A and B and such that Max_A < Min_B. If there was
10834 // a solution between Max_A and Min_B, it would have to be caused by an
10835 // overflow corresponding to either A or B. It cannot correspond to B,
10836 // since Min_B is the first occurrence of such an overflow. If it
10837 // corresponded to A, it would have to be either a signed or an unsigned
10838 // overflow that is larger than both eliminated overflows for A. But
10839 // between the eliminated overflows and this overflow, the values would
10840 // cover the entire value space, thus crossing the other boundary, which
10841 // is a contradiction.
10842
10843 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10844}
10845
10846ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10847 const Loop *L,
10848 bool ControlsOnlyExit,
10849 bool AllowPredicates) {
10850
10851 // This is only used for loops with a "x != y" exit test. The exit condition
10852 // is now expressed as a single expression, V = x-y. So the exit test is
10853 // effectively V != 0. We know and take advantage of the fact that this
10854 // expression only being used in a comparison by zero context.
10855
10857 // If the value is a constant
10858 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10859 // If the value is already zero, the branch will execute zero times.
10860 if (C->getValue()->isZero()) return C;
10861 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10862 }
10863
10864 const SCEVAddRecExpr *AddRec =
10865 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10866
10867 if (!AddRec && AllowPredicates)
10868 // Try to make this an AddRec using runtime tests, in the first X
10869 // iterations of this loop, where X is the SCEV expression found by the
10870 // algorithm below.
10871 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10872
10873 if (!AddRec || AddRec->getLoop() != L)
10874 return getCouldNotCompute();
10875
10876 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10877 // the quadratic equation to solve it.
10878 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10879 // We can only use this value if the chrec ends up with an exact zero
10880 // value at this index. When solving for "X*X != 5", for example, we
10881 // should not accept a root of 2.
10882 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10883 const auto *R = cast<SCEVConstant>(getConstant(*S));
10884 return ExitLimit(R, R, R, false, Predicates);
10885 }
10886 return getCouldNotCompute();
10887 }
10888
10889 // Otherwise we can only handle this if it is affine.
10890 if (!AddRec->isAffine())
10891 return getCouldNotCompute();
10892
10893 // If this is an affine expression, the execution count of this branch is
10894 // the minimum unsigned root of the following equation:
10895 //
10896 // Start + Step*N = 0 (mod 2^BW)
10897 //
10898 // equivalent to:
10899 //
10900 // Step*N = -Start (mod 2^BW)
10901 //
10902 // where BW is the common bit width of Start and Step.
10903
10904 // Get the initial value for the loop.
10905 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10906 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10907
10908 if (!isLoopInvariant(Step, L))
10909 return getCouldNotCompute();
10910
10911 LoopGuards Guards = LoopGuards::collect(L, *this);
10912 // Specialize step for this loop so we get context sensitive facts below.
10913 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10914
10915 // For positive steps (counting up until unsigned overflow):
10916 // N = -Start/Step (as unsigned)
10917 // For negative steps (counting down to zero):
10918 // N = Start/-Step
10919 // First compute the unsigned distance from zero in the direction of Step.
10920 bool CountDown = isKnownNegative(StepWLG);
10921 if (!CountDown && !isKnownNonNegative(StepWLG))
10922 return getCouldNotCompute();
10923
10924 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10925 // Handle unitary steps, which cannot wraparound.
10926 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10927 // N = Distance (as unsigned)
10928
10929 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10930 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10931 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10932
10933 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10934 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10935 // case, and see if we can improve the bound.
10936 //
10937 // Explicitly handling this here is necessary because getUnsignedRange
10938 // isn't context-sensitive; it doesn't know that we only care about the
10939 // range inside the loop.
10940 const SCEV *Zero = getZero(Distance->getType());
10941 const SCEV *One = getOne(Distance->getType());
10942 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10943 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10944 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10945 // as "unsigned_max(Distance + 1) - 1".
10946 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10947 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10948 }
10949 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10950 Predicates);
10951 }
10952
10953 // If the condition controls loop exit (the loop exits only if the expression
10954 // is true) and the addition is no-wrap we can use unsigned divide to
10955 // compute the backedge count. In this case, the step may not divide the
10956 // distance, but we don't care because if the condition is "missed" the loop
10957 // will have undefined behavior due to wrapping.
10958 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10959 loopHasNoAbnormalExits(AddRec->getLoop())) {
10960
10961 // If the stride is zero and the start is non-zero, the loop must be
10962 // infinite. In C++, most loops are finite by assumption, in which case the
10963 // step being zero implies UB must execute if the loop is entered.
10964 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10965 !isKnownNonZero(StepWLG))
10966 return getCouldNotCompute();
10967
10968 const SCEV *Exact =
10969 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10970 const SCEV *ConstantMax = getCouldNotCompute();
10971 if (Exact != getCouldNotCompute()) {
10972 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10973 ConstantMax =
10975 }
10976 const SCEV *SymbolicMax =
10977 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10978 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10979 }
10980
10981 // Solve the general equation.
10982 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10983 if (!StepC || StepC->getValue()->isZero())
10984 return getCouldNotCompute();
10985 const SCEV *E = SolveLinEquationWithOverflow(
10986 StepC->getAPInt(), getNegativeSCEV(Start),
10987 AllowPredicates ? &Predicates : nullptr, *this, L);
10988
10989 const SCEV *M = E;
10990 if (E != getCouldNotCompute()) {
10991 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10992 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10993 }
10994 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10995 return ExitLimit(E, M, S, false, Predicates);
10996}
10997
10998ScalarEvolution::ExitLimit
10999ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
11000 // Loops that look like: while (X == 0) are very strange indeed. We don't
11001 // handle them yet except for the trivial case. This could be expanded in the
11002 // future as needed.
11003
11004 // If the value is a constant, check to see if it is known to be non-zero
11005 // already. If so, the backedge will execute zero times.
11006 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
11007 if (!C->getValue()->isZero())
11008 return getZero(C->getType());
11009 return getCouldNotCompute(); // Otherwise it will loop infinitely.
11010 }
11011
11012 // We could implement others, but I really doubt anyone writes loops like
11013 // this, and if they did, they would already be constant folded.
11014 return getCouldNotCompute();
11015}
11016
11017std::pair<const BasicBlock *, const BasicBlock *>
11018ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
11019 const {
11020 // If the block has a unique predecessor, then there is no path from the
11021 // predecessor to the block that does not go through the direct edge
11022 // from the predecessor to the block.
11023 if (const BasicBlock *Pred = BB->getSinglePredecessor())
11024 return {Pred, BB};
11025
11026 // A loop's header is defined to be a block that dominates the loop.
11027 // If the header has a unique predecessor outside the loop, it must be
11028 // a block that has exactly one successor that can reach the loop.
11029 if (const Loop *L = LI.getLoopFor(BB))
11030 return {L->getLoopPredecessor(), L->getHeader()};
11031
11032 return {nullptr, BB};
11033}
11034
11035/// SCEV structural equivalence is usually sufficient for testing whether two
11036/// expressions are equal, however for the purposes of looking for a condition
11037/// guarding a loop, it can be useful to be a little more general, since a
11038/// front-end may have replicated the controlling expression.
11039static bool HasSameValue(const SCEV *A, const SCEV *B) {
11040 // Quick check to see if they are the same SCEV.
11041 if (A == B) return true;
11042
11043 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11044 // Not all instructions that are "identical" compute the same value. For
11045 // instance, two distinct alloca instructions allocating the same type are
11046 // identical and do not read memory; but compute distinct values.
11047 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11048 };
11049
11050 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11051 // two different instructions with the same value. Check for this case.
11052 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11053 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11054 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11055 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11056 if (ComputesEqualValues(AI, BI))
11057 return true;
11058
11059 // Otherwise assume they may have a different value.
11060 return false;
11061}
11062
11063static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11064 const SCEV *Op0, *Op1;
11065 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11066 return false;
11067 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11068 LHS = Op1;
11069 return true;
11070 }
11071 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11072 LHS = Op0;
11073 return true;
11074 }
11075 return false;
11076}
11077
11079 SCEVUse &RHS, unsigned Depth) {
11080 bool Changed = false;
11081 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11082 // '0 != 0'.
11083 auto TrivialCase = [&](bool TriviallyTrue) {
11085 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11086 return true;
11087 };
11088 // If we hit the max recursion limit bail out.
11089 if (Depth >= 3)
11090 return false;
11091
11092 const SCEV *NewLHS, *NewRHS;
11093 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11094 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11095 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11096 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11097
11098 // (X * vscale) pred (Y * vscale) ==> X pred Y
11099 // when both multiples are NSW.
11100 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11101 // when both multiples are NUW.
11102 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11103 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11104 !ICmpInst::isSigned(Pred))) {
11105 LHS = NewLHS;
11106 RHS = NewRHS;
11107 Changed = true;
11108 }
11109 }
11110
11111 // Canonicalize a constant to the right side.
11112 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11113 // Check for both operands constant.
11114 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11115 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11116 return TrivialCase(false);
11117 return TrivialCase(true);
11118 }
11119 // Otherwise swap the operands to put the constant on the right.
11120 std::swap(LHS, RHS);
11122 Changed = true;
11123 }
11124
11125 // (K + A) pred (K + B) --> A pred B
11126 // For equality, no flags are needed.
11127 // For signed, both adds must be NSW. For unsigned, both must be NUW.
11128 {
11129 const SCEVConstant *C = nullptr;
11130 if (match(LHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(NewLHS))) &&
11131 match(RHS, m_scev_Add(m_scev_Specific(C), m_SCEV(NewRHS)))) {
11132 const auto *LAdd = cast<SCEVAddExpr>(LHS);
11133 const auto *RAdd = cast<SCEVAddExpr>(RHS);
11134 if (ICmpInst::isEquality(Pred) ||
11135 (ICmpInst::isSigned(Pred) && LAdd->hasNoSignedWrap() &&
11136 RAdd->hasNoSignedWrap()) ||
11137 (ICmpInst::isUnsigned(Pred) && LAdd->hasNoUnsignedWrap() &&
11138 RAdd->hasNoUnsignedWrap())) {
11139 LHS = NewLHS;
11140 RHS = NewRHS;
11141 Changed = true;
11142 }
11143 }
11144 }
11145
11146 // (C * A) pred (C * B) --> A pred B
11147 // For equality predicates, both muls must be NUW or both must be NSW
11148 // (either suffices to make multiplication by C injective; C == 0 is
11149 // impossible because SCEV folds 0 * X to 0).
11150 // For signed ordering, C must be positive and both muls must be NSW.
11151 // For unsigned ordering, both muls must be NUW.
11152 {
11153 const SCEVConstant *C = nullptr;
11154 if (match(LHS, m_scev_Mul(m_SCEVConstant(C), m_SCEV(NewLHS))) &&
11155 match(RHS, m_scev_Mul(m_scev_Specific(C), m_SCEV(NewRHS)))) {
11156 const auto *LMul = cast<SCEVMulExpr>(LHS);
11157 const auto *RMul = cast<SCEVMulExpr>(RHS);
11158 bool BothNUW = LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap();
11159 bool BothNSW = LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap();
11160 if ((ICmpInst::isEquality(Pred) && (BothNUW || BothNSW)) ||
11161 (ICmpInst::isSigned(Pred) && BothNSW &&
11162 C->getAPInt().isStrictlyPositive()) ||
11163 (ICmpInst::isUnsigned(Pred) && BothNUW)) {
11164 LHS = NewLHS;
11165 RHS = NewRHS;
11166 Changed = true;
11167 }
11168 }
11169 }
11170
11171 // If we're comparing an addrec with a value which is loop-invariant in the
11172 // addrec's loop, put the addrec on the left. Also make a dominance check,
11173 // as both operands could be addrecs loop-invariant in each other's loop.
11174 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11175 const Loop *L = AR->getLoop();
11176 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11177 std::swap(LHS, RHS);
11179 Changed = true;
11180 }
11181 }
11182
11183 // If there's a constant operand, canonicalize comparisons with boundary
11184 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11185 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11186 const APInt &RA = RC->getAPInt();
11187
11188 bool SimplifiedByConstantRange = false;
11189
11190 if (!ICmpInst::isEquality(Pred)) {
11192 if (ExactCR.isFullSet())
11193 return TrivialCase(true);
11194 if (ExactCR.isEmptySet())
11195 return TrivialCase(false);
11196
11197 APInt NewRHS;
11198 CmpInst::Predicate NewPred;
11199 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11200 ICmpInst::isEquality(NewPred)) {
11201 // We were able to convert an inequality to an equality.
11202 Pred = NewPred;
11203 RHS = getConstant(NewRHS);
11204 Changed = SimplifiedByConstantRange = true;
11205 }
11206 }
11207
11208 if (!SimplifiedByConstantRange) {
11209 switch (Pred) {
11210 default:
11211 break;
11212 case ICmpInst::ICMP_EQ:
11213 case ICmpInst::ICMP_NE:
11214 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11215 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11216 Changed = true;
11217 break;
11218
11219 // The "Should have been caught earlier!" messages refer to the fact
11220 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11221 // should have fired on the corresponding cases, and canonicalized the
11222 // check to trivial case.
11223
11224 case ICmpInst::ICMP_UGE:
11225 assert(!RA.isMinValue() && "Should have been caught earlier!");
11226 Pred = ICmpInst::ICMP_UGT;
11227 RHS = getConstant(RA - 1);
11228 Changed = true;
11229 break;
11230 case ICmpInst::ICMP_ULE:
11231 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11232 Pred = ICmpInst::ICMP_ULT;
11233 RHS = getConstant(RA + 1);
11234 Changed = true;
11235 break;
11236 case ICmpInst::ICMP_SGE:
11237 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11238 Pred = ICmpInst::ICMP_SGT;
11239 RHS = getConstant(RA - 1);
11240 Changed = true;
11241 break;
11242 case ICmpInst::ICMP_SLE:
11243 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11244 Pred = ICmpInst::ICMP_SLT;
11245 RHS = getConstant(RA + 1);
11246 Changed = true;
11247 break;
11248 }
11249 }
11250 }
11251
11252 // Check for obvious equality.
11253 if (HasSameValue(LHS, RHS)) {
11254 if (ICmpInst::isTrueWhenEqual(Pred))
11255 return TrivialCase(true);
11257 return TrivialCase(false);
11258 }
11259
11260 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11261 // adding or subtracting 1 from one of the operands.
11262 switch (Pred) {
11263 case ICmpInst::ICMP_SLE:
11264 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11265 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11267 Pred = ICmpInst::ICMP_SLT;
11268 Changed = true;
11269 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11270 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11272 Pred = ICmpInst::ICMP_SLT;
11273 Changed = true;
11274 }
11275 break;
11276 case ICmpInst::ICMP_SGE:
11277 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11278 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11280 Pred = ICmpInst::ICMP_SGT;
11281 Changed = true;
11282 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11283 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11285 Pred = ICmpInst::ICMP_SGT;
11286 Changed = true;
11287 }
11288 break;
11289 case ICmpInst::ICMP_ULE:
11290 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11291 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11293 Pred = ICmpInst::ICMP_ULT;
11294 Changed = true;
11295 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11296 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11297 Pred = ICmpInst::ICMP_ULT;
11298 Changed = true;
11299 }
11300 break;
11301 case ICmpInst::ICMP_UGE:
11302 // If RHS is an op we can fold the -1, try that first.
11303 // Otherwise prefer LHS to preserve the nuw flag.
11304 if ((isa<SCEVConstant>(RHS) ||
11306 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11307 !getUnsignedRangeMin(RHS).isMinValue()) {
11308 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11309 Pred = ICmpInst::ICMP_UGT;
11310 Changed = true;
11311 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11312 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11314 Pred = ICmpInst::ICMP_UGT;
11315 Changed = true;
11316 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11317 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11318 Pred = ICmpInst::ICMP_UGT;
11319 Changed = true;
11320 }
11321 break;
11322 default:
11323 break;
11324 }
11325
11326 // TODO: More simplifications are possible here.
11327
11328 // Recursively simplify until we either hit a recursion limit or nothing
11329 // changes.
11330 if (Changed)
11331 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11332
11333 return Changed;
11334}
11335
11337 return getSignedRangeMax(S).isNegative();
11338}
11339
11343
11345 return !getSignedRangeMin(S).isNegative();
11346}
11347
11351
11353 // Query push down for cases where the unsigned range is
11354 // less than sufficient.
11355 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11356 return isKnownNonZero(SExt->getOperand(0));
11357 return getUnsignedRangeMin(S) != 0;
11358}
11359
11361 bool OrNegative) {
11362 auto NonRecursive = [OrNegative](const SCEV *S) {
11363 if (auto *C = dyn_cast<SCEVConstant>(S))
11364 return C->getAPInt().isPowerOf2() ||
11365 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11366
11367 // vscale is a power-of-two.
11368 return isa<SCEVVScale>(S);
11369 };
11370
11371 if (NonRecursive(S))
11372 return true;
11373
11374 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11375 if (!Mul)
11376 return false;
11377 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11378}
11379
11381 const SCEV *S, uint64_t M,
11383 if (M == 0)
11384 return false;
11385 if (M == 1)
11386 return true;
11387
11388 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11389 // starts with a multiple of M and at every iteration step S only adds
11390 // multiples of M.
11391 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11392 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11393 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11394
11395 // For a constant, check that "S % M == 0".
11396 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11397 APInt C = Cst->getAPInt();
11398 return C.urem(M) == 0;
11399 }
11400
11401 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11402
11403 // Basic tests have failed.
11404 // Check "S % M == 0" at compile time and record runtime Assumptions.
11405 auto *STy = dyn_cast<IntegerType>(S->getType());
11406 const SCEV *SmodM =
11407 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11408 const SCEV *Zero = getZero(STy);
11409
11410 // Check whether "S % M == 0" is known at compile time.
11411 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11412 return true;
11413
11414 // Check whether "S % M != 0" is known at compile time.
11415 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11416 return false;
11417
11419
11420 // Detect redundant predicates.
11421 for (auto *A : Assumptions)
11422 if (A->implies(P, *this))
11423 return true;
11424
11425 // Only record non-redundant predicates.
11426 Assumptions.push_back(P);
11427 return true;
11428}
11429
11431 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11433}
11434
11435std::pair<const SCEV *, const SCEV *>
11437 // Compute SCEV on entry of loop L.
11438 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11439 if (Start == getCouldNotCompute())
11440 return { Start, Start };
11441 // Compute post increment SCEV for loop L.
11442 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11443 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11444 return { Start, PostInc };
11445}
11446
11448 SCEVUse RHS) {
11449 // First collect all loops.
11451 getUsedLoops(LHS, LoopsUsed);
11452 getUsedLoops(RHS, LoopsUsed);
11453
11454 if (LoopsUsed.empty())
11455 return false;
11456
11457 // Domination relationship must be a linear order on collected loops.
11458#ifndef NDEBUG
11459 for (const auto *L1 : LoopsUsed)
11460 for (const auto *L2 : LoopsUsed)
11461 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11462 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11463 "Domination relationship is not a linear order");
11464#endif
11465
11466 const Loop *MDL =
11467 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11468 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11469 });
11470
11471 // Get init and post increment value for LHS.
11472 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11473 // if LHS contains unknown non-invariant SCEV then bail out.
11474 if (SplitLHS.first == getCouldNotCompute())
11475 return false;
11476 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11477 // Get init and post increment value for RHS.
11478 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11479 // if RHS contains unknown non-invariant SCEV then bail out.
11480 if (SplitRHS.first == getCouldNotCompute())
11481 return false;
11482 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11483 // It is possible that init SCEV contains an invariant load but it does
11484 // not dominate MDL and is not available at MDL loop entry, so we should
11485 // check it here.
11486 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11487 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11488 return false;
11489
11490 // It seems backedge guard check is faster than entry one so in some cases
11491 // it can speed up whole estimation by short circuit
11492 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11493 SplitRHS.second) &&
11494 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11495}
11496
11498 SCEVUse RHS) {
11499 // Canonicalize the inputs first.
11500 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11501
11502 if (isKnownViaInduction(Pred, LHS, RHS))
11503 return true;
11504
11505 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11506 return true;
11507
11508 // Otherwise see what can be done with some simple reasoning.
11509 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11510}
11511
11513 const SCEV *LHS,
11514 const SCEV *RHS) {
11515 if (isKnownPredicate(Pred, LHS, RHS))
11516 return true;
11518 return false;
11519 return std::nullopt;
11520}
11521
11523 const SCEV *RHS,
11524 const Instruction *CtxI) {
11525 // TODO: Analyze guards and assumes from Context's block.
11526 return isKnownPredicate(Pred, LHS, RHS) ||
11527 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11528}
11529
11530std::optional<bool>
11532 const SCEV *RHS, const Instruction *CtxI) {
11533 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11534 if (KnownWithoutContext)
11535 return KnownWithoutContext;
11536
11537 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11538 return true;
11540 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11541 return false;
11542 return std::nullopt;
11543}
11544
11546 const SCEVAddRecExpr *LHS,
11547 const SCEV *RHS) {
11548 const Loop *L = LHS->getLoop();
11549 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11550 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11551}
11552
11553std::optional<ScalarEvolution::MonotonicPredicateType>
11555 ICmpInst::Predicate Pred) {
11556 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11557
11558#ifndef NDEBUG
11559 // Verify an invariant: inverting the predicate should turn a monotonically
11560 // increasing change to a monotonically decreasing one, and vice versa.
11561 if (Result) {
11562 auto ResultSwapped =
11563 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11564
11565 assert(*ResultSwapped != *Result &&
11566 "monotonicity should flip as we flip the predicate");
11567 }
11568#endif
11569
11570 return Result;
11571}
11572
11573std::optional<ScalarEvolution::MonotonicPredicateType>
11574ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11575 ICmpInst::Predicate Pred) {
11576 // A zero step value for LHS means the induction variable is essentially a
11577 // loop invariant value. We don't really depend on the predicate actually
11578 // flipping from false to true (for increasing predicates, and the other way
11579 // around for decreasing predicates), all we care about is that *if* the
11580 // predicate changes then it only changes from false to true.
11581 //
11582 // A zero step value in itself is not very useful, but there may be places
11583 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11584 // as general as possible.
11585
11586 // Only handle LE/LT/GE/GT predicates.
11587 if (!ICmpInst::isRelational(Pred))
11588 return std::nullopt;
11589
11590 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11591 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11592 "Should be greater or less!");
11593
11594 // Check that AR does not wrap.
11595 if (ICmpInst::isUnsigned(Pred)) {
11596 if (!LHS->hasNoUnsignedWrap())
11597 return std::nullopt;
11599 }
11600 assert(ICmpInst::isSigned(Pred) &&
11601 "Relational predicate is either signed or unsigned!");
11602 if (!LHS->hasNoSignedWrap())
11603 return std::nullopt;
11604
11605 const SCEV *Step = LHS->getStepRecurrence(*this);
11606
11607 if (isKnownNonNegative(Step))
11609
11610 if (isKnownNonPositive(Step))
11612
11613 return std::nullopt;
11614}
11615
11616std::optional<ScalarEvolution::LoopInvariantPredicate>
11618 const SCEV *RHS, const Loop *L,
11619 const Instruction *CtxI) {
11620 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11621 if (!isLoopInvariant(RHS, L)) {
11622 if (!isLoopInvariant(LHS, L))
11623 return std::nullopt;
11624
11625 std::swap(LHS, RHS);
11627 }
11628
11629 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11630 if (!ArLHS || ArLHS->getLoop() != L)
11631 return std::nullopt;
11632
11633 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11634 if (!MonotonicType)
11635 return std::nullopt;
11636 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11637 // true as the loop iterates, and the backedge is control dependent on
11638 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11639 //
11640 // * if the predicate was false in the first iteration then the predicate
11641 // is never evaluated again, since the loop exits without taking the
11642 // backedge.
11643 // * if the predicate was true in the first iteration then it will
11644 // continue to be true for all future iterations since it is
11645 // monotonically increasing.
11646 //
11647 // For both the above possibilities, we can replace the loop varying
11648 // predicate with its value on the first iteration of the loop (which is
11649 // loop invariant).
11650 //
11651 // A similar reasoning applies for a monotonically decreasing predicate, by
11652 // replacing true with false and false with true in the above two bullets.
11654 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11655
11656 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11658 RHS);
11659
11660 if (!CtxI)
11661 return std::nullopt;
11662 // Try to prove via context.
11663 // TODO: Support other cases.
11664 switch (Pred) {
11665 default:
11666 break;
11667 case ICmpInst::ICMP_ULE:
11668 case ICmpInst::ICMP_ULT: {
11669 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11670 // Given preconditions
11671 // (1) ArLHS does not cross the border of positive and negative parts of
11672 // range because of:
11673 // - Positive step; (TODO: lift this limitation)
11674 // - nuw - does not cross zero boundary;
11675 // - nsw - does not cross SINT_MAX boundary;
11676 // (2) ArLHS <s RHS
11677 // (3) RHS >=s 0
11678 // we can replace the loop variant ArLHS <u RHS condition with loop
11679 // invariant Start(ArLHS) <u RHS.
11680 //
11681 // Because of (1) there are two options:
11682 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11683 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11684 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11685 // Because of (2) ArLHS <u RHS is trivially true.
11686 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11687 // We can strengthen this to Start(ArLHS) <u RHS.
11688 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11689 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11690 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11691 isKnownNonNegative(RHS) &&
11692 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11694 RHS);
11695 }
11696 }
11697
11698 return std::nullopt;
11699}
11700
11701std::optional<ScalarEvolution::LoopInvariantPredicate>
11703 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11704 const Instruction *CtxI, const SCEV *MaxIter) {
11706 Pred, LHS, RHS, L, CtxI, MaxIter))
11707 return LIP;
11708 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11709 // Number of iterations expressed as UMIN isn't always great for expressing
11710 // the value on the last iteration. If the straightforward approach didn't
11711 // work, try the following trick: if the a predicate is invariant for X, it
11712 // is also invariant for umin(X, ...). So try to find something that works
11713 // among subexpressions of MaxIter expressed as umin.
11714 for (SCEVUse Op : UMin->operands())
11716 Pred, LHS, RHS, L, CtxI, Op))
11717 return LIP;
11718 return std::nullopt;
11719}
11720
11721std::optional<ScalarEvolution::LoopInvariantPredicate>
11723 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11724 const Instruction *CtxI, const SCEV *MaxIter) {
11725 // Try to prove the following set of facts:
11726 // - The predicate is monotonic in the iteration space.
11727 // - If the check does not fail on the 1st iteration:
11728 // - No overflow will happen during first MaxIter iterations;
11729 // - It will not fail on the MaxIter'th iteration.
11730 // If the check does fail on the 1st iteration, we leave the loop and no
11731 // other checks matter.
11732
11733 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11734 if (!isLoopInvariant(RHS, L)) {
11735 if (!isLoopInvariant(LHS, L))
11736 return std::nullopt;
11737
11738 std::swap(LHS, RHS);
11740 }
11741
11742 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11743 if (!AR || AR->getLoop() != L)
11744 return std::nullopt;
11745
11746 // Even if both are valid, we need to consistently chose the unsigned or the
11747 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11748 // predicate.
11749 Pred = Pred.dropSameSign();
11750
11751 // The predicate must be relational (i.e. <, <=, >=, >).
11752 if (!ICmpInst::isRelational(Pred))
11753 return std::nullopt;
11754
11755 // TODO: Support steps other than +/- 1.
11756 const SCEV *Step = AR->getStepRecurrence(*this);
11757 auto *One = getOne(Step->getType());
11758 auto *MinusOne = getNegativeSCEV(One);
11759 if (Step != One && Step != MinusOne)
11760 return std::nullopt;
11761
11762 // Type mismatch here means that MaxIter is potentially larger than max
11763 // unsigned value in start type, which mean we cannot prove no wrap for the
11764 // indvar.
11765 if (AR->getType() != MaxIter->getType())
11766 return std::nullopt;
11767
11768 // Value of IV on suggested last iteration.
11769 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11770 // Does it still meet the requirement?
11771 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11772 return std::nullopt;
11773 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11774 // not exceed max unsigned value of this type), this effectively proves
11775 // that there is no wrap during the iteration. To prove that there is no
11776 // signed/unsigned wrap, we need to check that
11777 // Start <= Last for step = 1 or Start >= Last for step = -1.
11778 ICmpInst::Predicate NoOverflowPred =
11780 if (Step == MinusOne)
11781 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11782 const SCEV *Start = AR->getStart();
11783 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11784 return std::nullopt;
11785
11786 // Everything is fine.
11787 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11788}
11789
11790bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11791 SCEVUse LHS,
11792 SCEVUse RHS) {
11793 if (HasSameValue(LHS, RHS))
11794 return ICmpInst::isTrueWhenEqual(Pred);
11795
11796 auto CheckRange = [&](bool IsSigned) {
11797 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11798 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11799 return RangeLHS.icmp(Pred, RangeRHS);
11800 };
11801
11802 // The check at the top of the function catches the case where the values are
11803 // known to be equal.
11804 if (Pred == CmpInst::ICMP_EQ)
11805 return false;
11806
11807 if (Pred == CmpInst::ICMP_NE) {
11808 if (CheckRange(true) || CheckRange(false))
11809 return true;
11810 auto *Diff = getMinusSCEV(LHS, RHS);
11811 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11812 }
11813
11814 return CheckRange(CmpInst::isSigned(Pred));
11815}
11816
11817bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11819 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11820 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11821 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11822 // OutC1 and OutC2.
11823 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11824 APInt &OutC2,
11825 SCEV::NoWrapFlags ExpectedFlags) {
11826 SCEVUse XNonConstOp, XConstOp;
11827 SCEVUse YNonConstOp, YConstOp;
11828 SCEV::NoWrapFlags XFlagsPresent;
11829 SCEV::NoWrapFlags YFlagsPresent;
11830
11831 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11832 XConstOp = getZero(X->getType());
11833 XNonConstOp = X;
11834 XFlagsPresent = ExpectedFlags;
11835 }
11836 if (!isa<SCEVConstant>(XConstOp))
11837 return false;
11838
11839 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11840 YConstOp = getZero(Y->getType());
11841 YNonConstOp = Y;
11842 YFlagsPresent = ExpectedFlags;
11843 }
11844
11845 if (YNonConstOp != XNonConstOp)
11846 return false;
11847
11848 if (!isa<SCEVConstant>(YConstOp))
11849 return false;
11850
11851 // When matching ADDs with NUW flags (and unsigned predicates), only the
11852 // second ADD (with the larger constant) requires NUW.
11853 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11854 return false;
11855 if (ExpectedFlags != SCEV::FlagNUW &&
11856 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11857 return false;
11858 }
11859
11860 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11861 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11862
11863 return true;
11864 };
11865
11866 APInt C1;
11867 APInt C2;
11868
11869 switch (Pred) {
11870 default:
11871 break;
11872
11873 case ICmpInst::ICMP_SGE:
11874 std::swap(LHS, RHS);
11875 [[fallthrough]];
11876 case ICmpInst::ICMP_SLE:
11877 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11878 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11879 return true;
11880
11881 break;
11882
11883 case ICmpInst::ICMP_SGT:
11884 std::swap(LHS, RHS);
11885 [[fallthrough]];
11886 case ICmpInst::ICMP_SLT:
11887 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11888 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11889 return true;
11890
11891 break;
11892
11893 case ICmpInst::ICMP_UGE:
11894 std::swap(LHS, RHS);
11895 [[fallthrough]];
11896 case ICmpInst::ICMP_ULE:
11897 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11898 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11899 return true;
11900
11901 break;
11902
11903 case ICmpInst::ICMP_UGT:
11904 std::swap(LHS, RHS);
11905 [[fallthrough]];
11906 case ICmpInst::ICMP_ULT:
11907 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11908 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11909 return true;
11910 break;
11911 }
11912
11913 return false;
11914}
11915
11916bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11918 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11919 return false;
11920
11921 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11922 // the stack can result in exponential time complexity.
11923 SaveAndRestore Restore(ProvingSplitPredicate, true);
11924
11925 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11926 //
11927 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11928 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11929 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11930 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11931 // use isKnownPredicate later if needed.
11932 return isKnownNonNegative(RHS) &&
11935}
11936
11937bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11938 const SCEV *LHS, const SCEV *RHS) {
11939 // No need to even try if we know the module has no guards.
11940 if (!HasGuards)
11941 return false;
11942
11943 return any_of(*BB, [&](const Instruction &I) {
11944 using namespace llvm::PatternMatch;
11945
11946 Value *Condition;
11948 m_Value(Condition))) &&
11949 isImpliedCond(Pred, LHS, RHS, Condition, false);
11950 });
11951}
11952
11953/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11954/// protected by a conditional between LHS and RHS. This is used to
11955/// to eliminate casts.
11957 CmpPredicate Pred,
11958 const SCEV *LHS,
11959 const SCEV *RHS) {
11960 // Interpret a null as meaning no loop, where there is obviously no guard
11961 // (interprocedural conditions notwithstanding). Do not bother about
11962 // unreachable loops.
11963 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11964 return true;
11965
11966 if (VerifyIR)
11967 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11968 "This cannot be done on broken IR!");
11969
11970
11971 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11972 return true;
11973
11974 BasicBlock *Latch = L->getLoopLatch();
11975 if (!Latch)
11976 return false;
11977
11978 CondBrInst *LoopContinuePredicate =
11980 if (LoopContinuePredicate &&
11981 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11982 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11983 return true;
11984
11985 // We don't want more than one activation of the following loops on the stack
11986 // -- that can lead to O(n!) time complexity.
11987 if (WalkingBEDominatingConds)
11988 return false;
11989
11990 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11991
11992 // See if we can exploit a trip count to prove the predicate.
11993 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11994 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11995 if (LatchBECount != getCouldNotCompute()) {
11996 // We know that Latch branches back to the loop header exactly
11997 // LatchBECount times. This means the backdege condition at Latch is
11998 // equivalent to "{0,+,1} u< LatchBECount".
11999 Type *Ty = LatchBECount->getType();
12000 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
12001 const SCEV *LoopCounter =
12002 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
12003 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
12004 LatchBECount))
12005 return true;
12006 }
12007
12008 // Check conditions due to any @llvm.assume intrinsics.
12009 for (auto &AssumeVH : AC.assumptions()) {
12010 if (!AssumeVH)
12011 continue;
12012 auto *CI = cast<CallInst>(AssumeVH);
12013 if (!DT.dominates(CI, Latch->getTerminator()))
12014 continue;
12015
12016 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
12017 return true;
12018 }
12019
12020 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
12021 return true;
12022
12023 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
12024 DTN != HeaderDTN; DTN = DTN->getIDom()) {
12025 assert(DTN && "should reach the loop header before reaching the root!");
12026
12027 BasicBlock *BB = DTN->getBlock();
12028 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
12029 return true;
12030
12031 BasicBlock *PBB = BB->getSinglePredecessor();
12032 if (!PBB)
12033 continue;
12034
12036 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
12037 continue;
12038
12039 // If we have an edge `E` within the loop body that dominates the only
12040 // latch, the condition guarding `E` also guards the backedge. This
12041 // reasoning works only for loops with a single latch.
12042 // We're constructively (and conservatively) enumerating edges within the
12043 // loop body that dominate the latch. The dominator tree better agree
12044 // with us on this:
12045 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
12046 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
12047 BB != ContBr->getSuccessor(0)))
12048 return true;
12049 }
12050
12051 return false;
12052}
12053
12055 CmpPredicate Pred,
12056 const SCEV *LHS,
12057 const SCEV *RHS) {
12058 // Do not bother proving facts for unreachable code.
12059 if (!DT.isReachableFromEntry(BB))
12060 return true;
12061 if (VerifyIR)
12062 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
12063 "This cannot be done on broken IR!");
12064
12065 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
12066 // the facts (a >= b && a != b) separately. A typical situation is when the
12067 // non-strict comparison is known from ranges and non-equality is known from
12068 // dominating predicates. If we are proving strict comparison, we always try
12069 // to prove non-equality and non-strict comparison separately.
12070 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
12071 const bool ProvingStrictComparison =
12072 Pred != NonStrictPredicate.dropSameSign();
12073 bool ProvedNonStrictComparison = false;
12074 bool ProvedNonEquality = false;
12075
12076 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
12077 if (!ProvedNonStrictComparison)
12078 ProvedNonStrictComparison = Fn(NonStrictPredicate);
12079 if (!ProvedNonEquality)
12080 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
12081 if (ProvedNonStrictComparison && ProvedNonEquality)
12082 return true;
12083 return false;
12084 };
12085
12086 if (ProvingStrictComparison) {
12087 auto ProofFn = [&](CmpPredicate P) {
12088 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12089 };
12090 if (SplitAndProve(ProofFn))
12091 return true;
12092 }
12093
12094 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12095 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12096 const Instruction *CtxI = &BB->front();
12097 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12098 return true;
12099 if (ProvingStrictComparison) {
12100 auto ProofFn = [&](CmpPredicate P) {
12101 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12102 };
12103 if (SplitAndProve(ProofFn))
12104 return true;
12105 }
12106 return false;
12107 };
12108
12109 // Starting at the block's predecessor, climb up the predecessor chain, as long
12110 // as there are predecessors that can be found that have unique successors
12111 // leading to the original block.
12112 const Loop *ContainingLoop = LI.getLoopFor(BB);
12113 const BasicBlock *PredBB;
12114 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12115 PredBB = ContainingLoop->getLoopPredecessor();
12116 else
12117 PredBB = BB->getSinglePredecessor();
12118 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12119 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12120 const CondBrInst *BlockEntryPredicate =
12121 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12122 if (!BlockEntryPredicate)
12123 continue;
12124
12125 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12126 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12127 return true;
12128 }
12129
12130 // Check conditions due to any @llvm.assume intrinsics.
12131 for (auto &AssumeVH : AC.assumptions()) {
12132 if (!AssumeVH)
12133 continue;
12134 auto *CI = cast<CallInst>(AssumeVH);
12135 if (!DT.dominates(CI, BB))
12136 continue;
12137
12138 if (ProveViaCond(CI->getArgOperand(0), false))
12139 return true;
12140 }
12141
12142 // Check conditions due to any @llvm.experimental.guard intrinsics.
12143 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12144 F.getParent(), Intrinsic::experimental_guard);
12145 if (GuardDecl)
12146 for (const auto *GU : GuardDecl->users())
12147 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12148 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12149 if (ProveViaCond(Guard->getArgOperand(0), false))
12150 return true;
12151 return false;
12152}
12153
12155 const SCEV *LHS,
12156 const SCEV *RHS) {
12157 // Interpret a null as meaning no loop, where there is obviously no guard
12158 // (interprocedural conditions notwithstanding).
12159 if (!L)
12160 return false;
12161
12162 // Both LHS and RHS must be available at loop entry.
12164 "LHS is not available at Loop Entry");
12166 "RHS is not available at Loop Entry");
12167
12168 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12169 return true;
12170
12171 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12172}
12173
12174bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12175 const SCEV *RHS,
12176 const Value *FoundCondValue, bool Inverse,
12177 const Instruction *CtxI) {
12178 // False conditions implies anything. Do not bother analyzing it further.
12179 if (FoundCondValue ==
12180 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12181 return true;
12182
12183 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12184 return false;
12185
12186 llvm::scope_exit ClearOnExit(
12187 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12188
12189 // Recursively handle And and Or conditions.
12190 const Value *Op0, *Op1;
12191 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12192 if (!Inverse)
12193 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12194 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12195 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12196 if (Inverse)
12197 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12198 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12199 }
12200
12201 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12202 if (!ICI) return false;
12203
12204 // Now that we found a conditional branch that dominates the loop or controls
12205 // the loop latch. Check to see if it is the comparison we are looking for.
12206 CmpPredicate FoundPred;
12207 if (Inverse)
12208 FoundPred = ICI->getInverseCmpPredicate();
12209 else
12210 FoundPred = ICI->getCmpPredicate();
12211
12212 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12213 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12214
12215 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12216}
12217
12218bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12219 const SCEV *RHS, CmpPredicate FoundPred,
12220 const SCEV *FoundLHS, const SCEV *FoundRHS,
12221 const Instruction *CtxI) {
12222 // Balance the types.
12223 if (getTypeSizeInBits(LHS->getType()) <
12224 getTypeSizeInBits(FoundLHS->getType())) {
12225 // For unsigned and equality predicates, try to prove that both found
12226 // operands fit into narrow unsigned range. If so, try to prove facts in
12227 // narrow types.
12228 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12229 !FoundRHS->getType()->isPointerTy()) {
12230 auto *NarrowType = LHS->getType();
12231 auto *WideType = FoundLHS->getType();
12232 auto BitWidth = getTypeSizeInBits(NarrowType);
12233 const SCEV *MaxValue = getZeroExtendExpr(
12235 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12236 MaxValue) &&
12237 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12238 MaxValue)) {
12239 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12240 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12241 // We cannot preserve samesign after truncation.
12242 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12243 TruncFoundLHS, TruncFoundRHS, CtxI))
12244 return true;
12245 }
12246 }
12247
12248 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12249 return false;
12250 if (CmpInst::isSigned(Pred)) {
12251 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12252 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12253 } else {
12254 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12255 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12256 }
12257 } else if (getTypeSizeInBits(LHS->getType()) >
12258 getTypeSizeInBits(FoundLHS->getType())) {
12259 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12260 return false;
12261 if (CmpInst::isSigned(FoundPred)) {
12262 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12263 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12264 } else {
12265 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12266 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12267 }
12268 }
12269 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12270 FoundRHS, CtxI);
12271}
12272
12273bool ScalarEvolution::isImpliedCondBalancedTypes(
12274 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12275 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12277 getTypeSizeInBits(FoundLHS->getType()) &&
12278 "Types should be balanced!");
12279 // Canonicalize the query to match the way instcombine will have
12280 // canonicalized the comparison.
12281 if (SimplifyICmpOperands(Pred, LHS, RHS))
12282 if (LHS == RHS)
12283 return CmpInst::isTrueWhenEqual(Pred);
12284 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12285 if (FoundLHS == FoundRHS)
12286 return CmpInst::isFalseWhenEqual(FoundPred);
12287
12288 // Check to see if we can make the LHS or RHS match.
12289 if (LHS == FoundRHS || RHS == FoundLHS) {
12290 if (isa<SCEVConstant>(RHS)) {
12291 std::swap(FoundLHS, FoundRHS);
12292 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12293 } else {
12294 std::swap(LHS, RHS);
12296 }
12297 }
12298
12299 // Check whether the found predicate is the same as the desired predicate.
12300 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12301 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12302
12303 // Check whether swapping the found predicate makes it the same as the
12304 // desired predicate.
12305 if (auto P = CmpPredicate::getMatching(
12306 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12307 // We can write the implication
12308 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12309 // using one of the following ways:
12310 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12311 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12312 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12313 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12314 // Forms 1. and 2. require swapping the operands of one condition. Don't
12315 // do this if it would break canonical constant/addrec ordering.
12317 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12318 LHS, FoundLHS, FoundRHS, CtxI);
12319 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12320 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12321
12322 // There's no clear preference between forms 3. and 4., try both. Avoid
12323 // forming getNotSCEV of pointer values as the resulting subtract is
12324 // not legal.
12325 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12326 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12327 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12328 FoundRHS, CtxI))
12329 return true;
12330
12331 if (!FoundLHS->getType()->isPointerTy() &&
12332 !FoundRHS->getType()->isPointerTy() &&
12333 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12334 getNotSCEV(FoundRHS), CtxI))
12335 return true;
12336
12337 return false;
12338 }
12339
12340 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12342 assert(P1 != P2 && "Handled earlier!");
12343 return CmpInst::isRelational(P2) &&
12345 };
12346 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12347 // Unsigned comparison is the same as signed comparison when both the
12348 // operands are non-negative or negative.
12349 if (haveSameSign(FoundLHS, FoundRHS))
12350 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12351 // Create local copies that we can freely swap and canonicalize our
12352 // conditions to "le/lt".
12353 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12354 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12355 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12356 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12357 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12358 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12359 std::swap(CanonicalLHS, CanonicalRHS);
12360 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12361 }
12362 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12363 "Must be!");
12364 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12365 ICmpInst::isLE(CanonicalFoundPred)) &&
12366 "Must be!");
12367 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12368 // Use implication:
12369 // x <u y && y >=s 0 --> x <s y.
12370 // If we can prove the left part, the right part is also proven.
12371 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12372 CanonicalRHS, CanonicalFoundLHS,
12373 CanonicalFoundRHS);
12374 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12375 // Use implication:
12376 // x <s y && y <s 0 --> x <u y.
12377 // If we can prove the left part, the right part is also proven.
12378 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12379 CanonicalRHS, CanonicalFoundLHS,
12380 CanonicalFoundRHS);
12381 }
12382
12383 // Check if we can make progress by sharpening ranges.
12384 if (FoundPred == ICmpInst::ICMP_NE &&
12385 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12386
12387 const SCEVConstant *C = nullptr;
12388 const SCEV *V = nullptr;
12389
12390 if (isa<SCEVConstant>(FoundLHS)) {
12391 C = cast<SCEVConstant>(FoundLHS);
12392 V = FoundRHS;
12393 } else {
12394 C = cast<SCEVConstant>(FoundRHS);
12395 V = FoundLHS;
12396 }
12397
12398 // The guarding predicate tells us that C != V. If the known range
12399 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12400 // range we consider has to correspond to same signedness as the
12401 // predicate we're interested in folding.
12402
12403 APInt Min = ICmpInst::isSigned(Pred) ?
12405
12406 if (Min == C->getAPInt()) {
12407 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12408 // This is true even if (Min + 1) wraps around -- in case of
12409 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12410
12411 APInt SharperMin = Min + 1;
12412
12413 switch (Pred) {
12414 case ICmpInst::ICMP_SGE:
12415 case ICmpInst::ICMP_UGE:
12416 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12417 // RHS, we're done.
12418 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12419 CtxI))
12420 return true;
12421 [[fallthrough]];
12422
12423 case ICmpInst::ICMP_SGT:
12424 case ICmpInst::ICMP_UGT:
12425 // We know from the range information that (V `Pred` Min ||
12426 // V == Min). We know from the guarding condition that !(V
12427 // == Min). This gives us
12428 //
12429 // V `Pred` Min || V == Min && !(V == Min)
12430 // => V `Pred` Min
12431 //
12432 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12433
12434 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12435 return true;
12436 break;
12437
12438 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12439 case ICmpInst::ICMP_SLE:
12440 case ICmpInst::ICMP_ULE:
12441 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12442 LHS, V, getConstant(SharperMin), CtxI))
12443 return true;
12444 [[fallthrough]];
12445
12446 case ICmpInst::ICMP_SLT:
12447 case ICmpInst::ICMP_ULT:
12448 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12449 LHS, V, getConstant(Min), CtxI))
12450 return true;
12451 break;
12452
12453 default:
12454 // No change
12455 break;
12456 }
12457 }
12458 }
12459
12460 // Check whether the actual condition is beyond sufficient.
12461 if (FoundPred == ICmpInst::ICMP_EQ)
12462 if (ICmpInst::isTrueWhenEqual(Pred))
12463 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12464 return true;
12465 if (Pred == ICmpInst::ICMP_NE)
12466 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12467 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12468 return true;
12469
12470 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12471 return true;
12472
12473 // Otherwise assume the worst.
12474 return false;
12475}
12476
12477bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12478 SCEV::NoWrapFlags &Flags) {
12479 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12480 return false;
12481
12482 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12483 return true;
12484}
12485
12486std::optional<APInt>
12488 // We avoid subtracting expressions here because this function is usually
12489 // fairly deep in the call stack (i.e. is called many times).
12490
12491 unsigned BW = getTypeSizeInBits(More->getType());
12492 APInt Diff(BW, 0);
12493 APInt DiffMul(BW, 1);
12494 // Try various simplifications to reduce the difference to a constant. Limit
12495 // the number of allowed simplifications to keep compile-time low.
12496 for (unsigned I = 0; I < 8; ++I) {
12497 if (More == Less)
12498 return Diff;
12499
12500 // Reduce addrecs with identical steps to their start value.
12502 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12503 const auto *MAR = cast<SCEVAddRecExpr>(More);
12504
12505 if (LAR->getLoop() != MAR->getLoop())
12506 return std::nullopt;
12507
12508 // We look at affine expressions only; not for correctness but to keep
12509 // getStepRecurrence cheap.
12510 if (!LAR->isAffine() || !MAR->isAffine())
12511 return std::nullopt;
12512
12513 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12514 return std::nullopt;
12515
12516 Less = LAR->getStart();
12517 More = MAR->getStart();
12518 continue;
12519 }
12520
12521 // Try to match a common constant multiply.
12522 auto MatchConstMul =
12523 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12524 const APInt *C;
12525 const SCEV *Op;
12526 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12527 return {{Op, *C}};
12528 return std::nullopt;
12529 };
12530 if (auto MatchedMore = MatchConstMul(More)) {
12531 if (auto MatchedLess = MatchConstMul(Less)) {
12532 if (MatchedMore->second == MatchedLess->second) {
12533 More = MatchedMore->first;
12534 Less = MatchedLess->first;
12535 DiffMul *= MatchedMore->second;
12536 continue;
12537 }
12538 }
12539 }
12540
12541 // Try to cancel out common factors in two add expressions.
12543 auto Add = [&](const SCEV *S, int Mul) {
12544 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12545 if (Mul == 1) {
12546 Diff += C->getAPInt() * DiffMul;
12547 } else {
12548 assert(Mul == -1);
12549 Diff -= C->getAPInt() * DiffMul;
12550 }
12551 } else
12552 Multiplicity[S] += Mul;
12553 };
12554 auto Decompose = [&](const SCEV *S, int Mul) {
12555 if (isa<SCEVAddExpr>(S)) {
12556 for (const SCEV *Op : S->operands())
12557 Add(Op, Mul);
12558 } else
12559 Add(S, Mul);
12560 };
12561 Decompose(More, 1);
12562 Decompose(Less, -1);
12563
12564 // Check whether all the non-constants cancel out, or reduce to new
12565 // More/Less values.
12566 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12567 for (const auto &[S, Mul] : Multiplicity) {
12568 if (Mul == 0)
12569 continue;
12570 if (Mul == 1) {
12571 if (NewMore)
12572 return std::nullopt;
12573 NewMore = S;
12574 } else if (Mul == -1) {
12575 if (NewLess)
12576 return std::nullopt;
12577 NewLess = S;
12578 } else
12579 return std::nullopt;
12580 }
12581
12582 // Values stayed the same, no point in trying further.
12583 if (NewMore == More || NewLess == Less)
12584 return std::nullopt;
12585
12586 More = NewMore;
12587 Less = NewLess;
12588
12589 // Reduced to constant.
12590 if (!More && !Less)
12591 return Diff;
12592
12593 // Left with variable on only one side, bail out.
12594 if (!More || !Less)
12595 return std::nullopt;
12596 }
12597
12598 // Did not reduce to constant.
12599 return std::nullopt;
12600}
12601
12602bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12603 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12604 const SCEV *FoundRHS, const Instruction *CtxI) {
12605 // Try to recognize the following pattern:
12606 //
12607 // FoundRHS = ...
12608 // ...
12609 // loop:
12610 // FoundLHS = {Start,+,W}
12611 // context_bb: // Basic block from the same loop
12612 // known(Pred, FoundLHS, FoundRHS)
12613 //
12614 // If some predicate is known in the context of a loop, it is also known on
12615 // each iteration of this loop, including the first iteration. Therefore, in
12616 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12617 // prove the original pred using this fact.
12618 if (!CtxI)
12619 return false;
12620 const BasicBlock *ContextBB = CtxI->getParent();
12621 // Make sure AR varies in the context block.
12622 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12623 const Loop *L = AR->getLoop();
12624 const auto *Latch = L->getLoopLatch();
12625 // Make sure that context belongs to the loop and executes on 1st iteration
12626 // (if it ever executes at all).
12627 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12628 return false;
12629 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12630 return false;
12631 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12632 }
12633
12634 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12635 const Loop *L = AR->getLoop();
12636 const auto *Latch = L->getLoopLatch();
12637 // Make sure that context belongs to the loop and executes on 1st iteration
12638 // (if it ever executes at all).
12639 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12640 return false;
12641 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12642 return false;
12643 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12644 }
12645
12646 return false;
12647}
12648
12649bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12650 const SCEV *LHS,
12651 const SCEV *RHS,
12652 const SCEV *FoundLHS,
12653 const SCEV *FoundRHS) {
12654 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12655 return false;
12656
12657 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12658 if (!AddRecLHS)
12659 return false;
12660
12661 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12662 if (!AddRecFoundLHS)
12663 return false;
12664
12665 // We'd like to let SCEV reason about control dependencies, so we constrain
12666 // both the inequalities to be about add recurrences on the same loop. This
12667 // way we can use isLoopEntryGuardedByCond later.
12668
12669 const Loop *L = AddRecFoundLHS->getLoop();
12670 if (L != AddRecLHS->getLoop())
12671 return false;
12672
12673 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12674 //
12675 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12676 // ... (2)
12677 //
12678 // Informal proof for (2), assuming (1) [*]:
12679 //
12680 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12681 //
12682 // Then
12683 //
12684 // FoundLHS s< FoundRHS s< INT_MIN - C
12685 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12686 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12687 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12688 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12689 // <=> FoundLHS + C s< FoundRHS + C
12690 //
12691 // [*]: (1) can be proved by ruling out overflow.
12692 //
12693 // [**]: This can be proved by analyzing all the four possibilities:
12694 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12695 // (A s>= 0, B s>= 0).
12696 //
12697 // Note:
12698 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12699 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12700 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12701 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12702 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12703 // C)".
12704
12705 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12706 if (!LDiff)
12707 return false;
12708 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12709 if (!RDiff || *LDiff != *RDiff)
12710 return false;
12711
12712 if (LDiff->isMinValue())
12713 return true;
12714
12715 APInt FoundRHSLimit;
12716
12717 if (Pred == CmpInst::ICMP_ULT) {
12718 FoundRHSLimit = -(*RDiff);
12719 } else {
12720 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12721 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12722 }
12723
12724 // Try to prove (1) or (2), as needed.
12725 return isAvailableAtLoopEntry(FoundRHS, L) &&
12726 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12727 getConstant(FoundRHSLimit));
12728}
12729
12730bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12731 const SCEV *RHS, const SCEV *FoundLHS,
12732 const SCEV *FoundRHS, unsigned Depth) {
12733 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12734
12735 llvm::scope_exit ClearOnExit([&]() {
12736 if (LPhi) {
12737 bool Erased = PendingMerges.erase(LPhi);
12738 assert(Erased && "Failed to erase LPhi!");
12739 (void)Erased;
12740 }
12741 if (RPhi) {
12742 bool Erased = PendingMerges.erase(RPhi);
12743 assert(Erased && "Failed to erase RPhi!");
12744 (void)Erased;
12745 }
12746 });
12747
12748 // Find respective Phis and check that they are not being pending.
12749 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12750 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12751 if (!PendingMerges.insert(Phi).second)
12752 return false;
12753 LPhi = Phi;
12754 }
12755 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12756 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12757 // If we detect a loop of Phi nodes being processed by this method, for
12758 // example:
12759 //
12760 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12761 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12762 //
12763 // we don't want to deal with a case that complex, so return conservative
12764 // answer false.
12765 if (!PendingMerges.insert(Phi).second)
12766 return false;
12767 RPhi = Phi;
12768 }
12769
12770 // If none of LHS, RHS is a Phi, nothing to do here.
12771 if (!LPhi && !RPhi)
12772 return false;
12773
12774 // If there is a SCEVUnknown Phi we are interested in, make it left.
12775 if (!LPhi) {
12776 std::swap(LHS, RHS);
12777 std::swap(FoundLHS, FoundRHS);
12778 std::swap(LPhi, RPhi);
12780 }
12781
12782 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12783 const BasicBlock *LBB = LPhi->getParent();
12784 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12785
12786 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12787 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12788 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12789 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12790 };
12791
12792 if (RPhi && RPhi->getParent() == LBB) {
12793 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12794 // If we compare two Phis from the same block, and for each entry block
12795 // the predicate is true for incoming values from this block, then the
12796 // predicate is also true for the Phis.
12797 for (const BasicBlock *IncBB : predecessors(LBB)) {
12798 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12799 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12800 if (!ProvedEasily(L, R))
12801 return false;
12802 }
12803 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12804 // Case two: RHS is also a Phi from the same basic block, and it is an
12805 // AddRec. It means that there is a loop which has both AddRec and Unknown
12806 // PHIs, for it we can compare incoming values of AddRec from above the loop
12807 // and latch with their respective incoming values of LPhi.
12808 // TODO: Generalize to handle loops with many inputs in a header.
12809 if (LPhi->getNumIncomingValues() != 2) return false;
12810
12811 auto *RLoop = RAR->getLoop();
12812 auto *Predecessor = RLoop->getLoopPredecessor();
12813 assert(Predecessor && "Loop with AddRec with no predecessor?");
12814 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12815 if (!ProvedEasily(L1, RAR->getStart()))
12816 return false;
12817 auto *Latch = RLoop->getLoopLatch();
12818 assert(Latch && "Loop with AddRec with no latch?");
12819 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12820 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12821 return false;
12822 } else {
12823 // In all other cases go over inputs of LHS and compare each of them to RHS,
12824 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12825 // At this point RHS is either a non-Phi, or it is a Phi from some block
12826 // different from LBB.
12827 for (const BasicBlock *IncBB : predecessors(LBB)) {
12828 // Check that RHS is available in this block.
12829 if (!dominates(RHS, IncBB))
12830 return false;
12831 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12832 // Make sure L does not refer to a value from a potentially previous
12833 // iteration of a loop.
12834 if (!properlyDominates(L, LBB))
12835 return false;
12836 // Addrecs are considered to properly dominate their loop, so are missed
12837 // by the previous check. Discard any values that have computable
12838 // evolution in this loop.
12839 if (auto *Loop = LI.getLoopFor(LBB))
12840 if (hasComputableLoopEvolution(L, Loop))
12841 return false;
12842 if (!ProvedEasily(L, RHS))
12843 return false;
12844 }
12845 }
12846 return true;
12847}
12848
12849bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12850 const SCEV *LHS,
12851 const SCEV *RHS,
12852 const SCEV *FoundLHS,
12853 const SCEV *FoundRHS) {
12854 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12855 // sure that we are dealing with same LHS.
12856 if (RHS == FoundRHS) {
12857 std::swap(LHS, RHS);
12858 std::swap(FoundLHS, FoundRHS);
12860 }
12861 if (LHS != FoundLHS)
12862 return false;
12863
12864 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12865 if (!SUFoundRHS)
12866 return false;
12867
12868 Value *Shiftee, *ShiftValue;
12869
12870 using namespace PatternMatch;
12871 if (match(SUFoundRHS->getValue(),
12872 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12873 auto *ShifteeS = getSCEV(Shiftee);
12874 // Prove one of the following:
12875 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12876 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12877 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12878 // ---> LHS <s RHS
12879 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12880 // ---> LHS <=s RHS
12881 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12882 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12883 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12884 if (isKnownNonNegative(ShifteeS))
12885 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12886 }
12887
12888 return false;
12889}
12890
12891bool ScalarEvolution::isImpliedCondOperandsViaMatchingDiff(
12892 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12893 const SCEV *FoundRHS) {
12894 // Only valid for equality predicates: (A == B) implies (C == D) when
12895 // the SCEV difference A - B equals C - D (they check the same
12896 // underlying relationship at every iteration).
12897 if (!ICmpInst::isEquality(Pred))
12898 return false;
12899
12900 // Restrict to cases involving loop recurrences - that's where this
12901 // pattern arises (correlated IV comparisons). This avoids calling
12902 // getMinusSCEV on arbitrary non-loop expressions.
12904 (!isa<SCEVAddRecExpr>(FoundLHS) && !isa<SCEVAddRecExpr>(FoundRHS)))
12905 return false;
12906
12907 // AddRecs from different loops can never produce matching differences.
12908 const SCEVAddRecExpr *QueryAddRec = dyn_cast<SCEVAddRecExpr>(LHS);
12909 if (!QueryAddRec)
12910 QueryAddRec = cast<SCEVAddRecExpr>(RHS);
12911 const SCEVAddRecExpr *FoundAddRec = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12912 if (!FoundAddRec)
12913 FoundAddRec = cast<SCEVAddRecExpr>(FoundRHS);
12914 if (QueryAddRec->getLoop() != FoundAddRec->getLoop())
12915 return false;
12916
12917 // If the strides differ, the differences can never match.
12918 if (QueryAddRec->getStepRecurrence(*this) !=
12919 FoundAddRec->getStepRecurrence(*this))
12920 return false;
12921
12922 // Compute differences. For pointer-typed operands sharing the same base,
12923 // getMinusSCEV strips the common base and returns an integer SCEV.
12924 // For example, {base,+,8} - (base+8*n) = {-8n,+,8}
12925 const SCEV *FoundDiff = getMinusSCEV(FoundLHS, FoundRHS);
12926 if (isa<SCEVCouldNotCompute>(FoundDiff))
12927 return false;
12928
12929 const SCEV *Diff = getMinusSCEV(LHS, RHS);
12930 if (isa<SCEVCouldNotCompute>(Diff))
12931 return false;
12932
12933 return Diff == FoundDiff;
12934}
12935
12936bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12937 const SCEV *RHS,
12938 const SCEV *FoundLHS,
12939 const SCEV *FoundRHS,
12940 const Instruction *CtxI) {
12941 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12942 FoundRHS) ||
12943 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12944 FoundRHS) ||
12945 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12946 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12947 CtxI) ||
12948 isImpliedCondOperandsViaMatchingDiff(Pred, LHS, RHS, FoundLHS,
12949 FoundRHS) ||
12950 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12951}
12952
12953/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12954template <typename MinMaxExprType>
12955static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12956 const SCEV *Candidate) {
12957 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12958 if (!MinMaxExpr)
12959 return false;
12960
12961 return is_contained(MinMaxExpr->operands(), Candidate);
12962}
12963
12965 CmpPredicate Pred, const SCEV *LHS,
12966 const SCEV *RHS) {
12967 // If both sides are affine addrecs for the same loop, with equal
12968 // steps, and we know the recurrences don't wrap, then we only
12969 // need to check the predicate on the starting values.
12970
12971 if (!ICmpInst::isRelational(Pred))
12972 return false;
12973
12974 const SCEV *LStart, *RStart, *Step;
12975 const Loop *L;
12976 if (!match(LHS,
12977 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12979 m_SpecificLoop(L))))
12980 return false;
12985 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12986 return false;
12987
12988 return SE.isKnownPredicate(Pred, LStart, RStart);
12989}
12990
12991/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12992/// expression?
12994 const SCEV *LHS, const SCEV *RHS) {
12995 switch (Pred) {
12996 default:
12997 return false;
12998
12999 case ICmpInst::ICMP_SGE:
13000 std::swap(LHS, RHS);
13001 [[fallthrough]];
13002 case ICmpInst::ICMP_SLE:
13003 return
13004 // min(A, ...) <= A
13006 // A <= max(A, ...)
13008
13009 case ICmpInst::ICMP_UGE:
13010 std::swap(LHS, RHS);
13011 [[fallthrough]];
13012 case ICmpInst::ICMP_ULE:
13013 return
13014 // min(A, ...) <= A
13015 // FIXME: what about umin_seq?
13017 // A <= max(A, ...)
13019 }
13020
13021 llvm_unreachable("covered switch fell through?!");
13022}
13023
13024bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
13025 const SCEV *RHS,
13026 const SCEV *FoundLHS,
13027 const SCEV *FoundRHS,
13028 unsigned Depth) {
13031 "LHS and RHS have different sizes?");
13032 assert(getTypeSizeInBits(FoundLHS->getType()) ==
13033 getTypeSizeInBits(FoundRHS->getType()) &&
13034 "FoundLHS and FoundRHS have different sizes?");
13035 // We want to avoid hurting the compile time with analysis of too big trees.
13037 return false;
13038
13039 // We only want to work with GT comparison so far.
13040 if (ICmpInst::isLT(Pred)) {
13042 std::swap(LHS, RHS);
13043 std::swap(FoundLHS, FoundRHS);
13044 }
13045
13047
13048 // For unsigned, try to reduce it to corresponding signed comparison.
13049 if (P == ICmpInst::ICMP_UGT)
13050 // We can replace unsigned predicate with its signed counterpart if all
13051 // involved values are non-negative.
13052 // TODO: We could have better support for unsigned.
13053 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
13054 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
13055 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
13056 // use this fact to prove that LHS and RHS are non-negative.
13057 const SCEV *MinusOne = getMinusOne(LHS->getType());
13058 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
13059 FoundRHS) &&
13060 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
13061 FoundRHS))
13063 }
13064
13065 if (P != ICmpInst::ICMP_SGT)
13066 return false;
13067
13068 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
13069 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
13070 return Ext->getOperand();
13071 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
13072 // the constant in some cases.
13073 return S;
13074 };
13075
13076 // Acquire values from extensions.
13077 auto *OrigLHS = LHS;
13078 auto *OrigFoundLHS = FoundLHS;
13079 LHS = GetOpFromSExt(LHS);
13080 FoundLHS = GetOpFromSExt(FoundLHS);
13081
13082 // Is the SGT predicate can be proved trivially or using the found context.
13083 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
13084 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
13085 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
13086 FoundRHS, Depth + 1);
13087 };
13088
13089 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
13090 // We want to avoid creation of any new non-constant SCEV. Since we are
13091 // going to compare the operands to RHS, we should be certain that we don't
13092 // need any size extensions for this. So let's decline all cases when the
13093 // sizes of types of LHS and RHS do not match.
13094 // TODO: Maybe try to get RHS from sext to catch more cases?
13096 return false;
13097
13098 // Should not overflow.
13099 if (!LHSAddExpr->hasNoSignedWrap())
13100 return false;
13101
13102 SCEVUse LL = LHSAddExpr->getOperand(0);
13103 SCEVUse LR = LHSAddExpr->getOperand(1);
13104 auto *MinusOne = getMinusOne(RHS->getType());
13105
13106 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
13107 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
13108 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
13109 };
13110 // Try to prove the following rule:
13111 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
13112 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
13113 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
13114 return true;
13115 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
13116 Value *LL, *LR;
13117 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
13118
13119 using namespace llvm::PatternMatch;
13120
13121 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
13122 // Rules for division.
13123 // We are going to perform some comparisons with Denominator and its
13124 // derivative expressions. In general case, creating a SCEV for it may
13125 // lead to a complex analysis of the entire graph, and in particular it
13126 // can request trip count recalculation for the same loop. This would
13127 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
13128 // this, we only want to create SCEVs that are constants in this section.
13129 // So we bail if Denominator is not a constant.
13130 if (!isa<ConstantInt>(LR))
13131 return false;
13132
13133 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
13134
13135 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13136 // then a SCEV for the numerator already exists and matches with FoundLHS.
13137 auto *Numerator = getExistingSCEV(LL);
13138 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13139 return false;
13140
13141 // Make sure that the numerator matches with FoundLHS and the denominator
13142 // is positive.
13143 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13144 return false;
13145
13146 auto *DTy = Denominator->getType();
13147 auto *FRHSTy = FoundRHS->getType();
13148 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13149 // One of types is a pointer and another one is not. We cannot extend
13150 // them properly to a wider type, so let us just reject this case.
13151 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13152 // to avoid this check.
13153 return false;
13154
13155 // Given that:
13156 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13157 auto *WTy = getWiderType(DTy, FRHSTy);
13158 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13159 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13160
13161 // Try to prove the following rule:
13162 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13163 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13164 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13165 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13166 if (isKnownNonPositive(RHS) &&
13167 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13168 return true;
13169
13170 // Try to prove the following rule:
13171 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13172 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13173 // If we divide it by Denominator > 2, then:
13174 // 1. If FoundLHS is negative, then the result is 0.
13175 // 2. If FoundLHS is non-negative, then the result is non-negative.
13176 // Anyways, the result is non-negative.
13177 auto *MinusOne = getMinusOne(WTy);
13178 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13179 if (isKnownNegative(RHS) &&
13180 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13181 return true;
13182 }
13183 }
13184
13185 // If our expression contained SCEVUnknown Phis, and we split it down and now
13186 // need to prove something for them, try to prove the predicate for every
13187 // possible incoming values of those Phis.
13188 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13189 return true;
13190
13191 return false;
13192}
13193
13195 const SCEV *RHS) {
13196 // zext x u<= sext x, sext x s<= zext x
13197 const SCEV *Op;
13198 switch (Pred) {
13199 case ICmpInst::ICMP_SGE:
13200 std::swap(LHS, RHS);
13201 [[fallthrough]];
13202 case ICmpInst::ICMP_SLE: {
13203 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13204 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13206 }
13207 case ICmpInst::ICMP_UGE:
13208 std::swap(LHS, RHS);
13209 [[fallthrough]];
13210 case ICmpInst::ICMP_ULE: {
13211 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13212 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13214 }
13215 default:
13216 return false;
13217 };
13218 llvm_unreachable("unhandled case");
13219}
13220
13221bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13222 SCEVUse LHS,
13223 SCEVUse RHS) {
13224 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13225 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13226 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13227 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13228 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13229}
13230
13231bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13232 const SCEV *LHS,
13233 const SCEV *RHS,
13234 const SCEV *FoundLHS,
13235 const SCEV *FoundRHS) {
13236 switch (Pred) {
13237 default:
13238 llvm_unreachable("Unexpected CmpPredicate value!");
13239 case ICmpInst::ICMP_EQ:
13240 case ICmpInst::ICMP_NE:
13241 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13242 return true;
13243 break;
13244 case ICmpInst::ICMP_SLT:
13245 case ICmpInst::ICMP_SLE:
13246 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13247 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13248 return true;
13249 break;
13250 case ICmpInst::ICMP_SGT:
13251 case ICmpInst::ICMP_SGE:
13252 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13253 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13254 return true;
13255 break;
13256 case ICmpInst::ICMP_ULT:
13257 case ICmpInst::ICMP_ULE:
13258 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13259 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13260 return true;
13261 break;
13262 case ICmpInst::ICMP_UGT:
13263 case ICmpInst::ICMP_UGE:
13264 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13265 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13266 return true;
13267 break;
13268 }
13269
13270 // Maybe it can be proved via operations?
13271 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13272 return true;
13273
13274 return false;
13275}
13276
13277bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13278 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13279 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13280 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13281 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13282 // reduce the compile time impact of this optimization.
13283 return false;
13284
13285 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13286 if (!Addend)
13287 return false;
13288
13289 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13290
13291 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13292 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13293 ConstantRange FoundLHSRange =
13294 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13295
13296 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13297 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13298
13299 // We can also compute the range of values for `LHS` that satisfy the
13300 // consequent, "`LHS` `Pred` `RHS`":
13301 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13302 // The antecedent implies the consequent if every value of `LHS` that
13303 // satisfies the antecedent also satisfies the consequent.
13304 return LHSRange.icmp(Pred, ConstRHS);
13305}
13306
13307bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13308 bool IsSigned) {
13309 assert(isKnownPositive(Stride) && "Positive stride expected!");
13310
13311 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13312 const SCEV *One = getOne(Stride->getType());
13313
13314 if (IsSigned) {
13315 APInt MaxRHS = getSignedRangeMax(RHS);
13316 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13317 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13318
13319 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13320 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13321 }
13322
13323 APInt MaxRHS = getUnsignedRangeMax(RHS);
13324 APInt MaxValue = APInt::getMaxValue(BitWidth);
13325 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13326
13327 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13328 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13329}
13330
13331bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13332 bool IsSigned) {
13333
13334 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13335 const SCEV *One = getOne(Stride->getType());
13336
13337 if (IsSigned) {
13338 APInt MinRHS = getSignedRangeMin(RHS);
13339 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13340 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13341
13342 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13343 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13344 }
13345
13346 APInt MinRHS = getUnsignedRangeMin(RHS);
13347 APInt MinValue = APInt::getMinValue(BitWidth);
13348 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13349
13350 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13351 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13352}
13353
13355 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13356 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13357 // expression fixes the case of N=0.
13358 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13359 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13360 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13361}
13362
13363const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13364 const SCEV *Stride,
13365 const SCEV *End,
13366 unsigned BitWidth,
13367 bool IsSigned) {
13368 // The logic in this function assumes we can represent a positive stride.
13369 // If we can't, the backedge-taken count must be zero.
13370 if (IsSigned && BitWidth == 1)
13371 return getZero(Stride->getType());
13372
13373 // This code below only been closely audited for negative strides in the
13374 // unsigned comparison case, it may be correct for signed comparison, but
13375 // that needs to be established.
13376 if (IsSigned && isKnownNegative(Stride))
13377 return getCouldNotCompute();
13378
13379 // Calculate the maximum backedge count based on the range of values
13380 // permitted by Start, End, and Stride.
13381 APInt MinStart =
13382 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13383
13384 APInt MinStride =
13385 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13386
13387 // We assume either the stride is positive, or the backedge-taken count
13388 // is zero. So force StrideForMaxBECount to be at least one.
13389 APInt One(BitWidth, 1);
13390 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13391 : APIntOps::umax(One, MinStride);
13392
13393 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13394 : APInt::getMaxValue(BitWidth);
13395 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13396
13397 // Although End can be a MAX expression we estimate MaxEnd considering only
13398 // the case End = RHS of the loop termination condition. This is safe because
13399 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13400 // taken count.
13401 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13402 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13403
13404 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13405 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13406 : APIntOps::umax(MaxEnd, MinStart);
13407
13408 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13409 getConstant(StrideForMaxBECount) /* Step */);
13410}
13411
13413ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13414 const Loop *L, bool IsSigned,
13415 bool ControlsOnlyExit, bool AllowPredicates) {
13417
13419 bool PredicatedIV = false;
13420 if (!IV) {
13421 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13422 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13423 if (AR && AR->getLoop() == L && AR->isAffine()) {
13424 auto canProveNUW = [&]() {
13425 // We can use the comparison to infer no-wrap flags only if it fully
13426 // controls the loop exit.
13427 if (!ControlsOnlyExit)
13428 return false;
13429
13430 if (!isLoopInvariant(RHS, L))
13431 return false;
13432
13433 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13434 // We need the sequence defined by AR to strictly increase in the
13435 // unsigned integer domain for the logic below to hold.
13436 return false;
13437
13438 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13439 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13440 // If RHS <=u Limit, then there must exist a value V in the sequence
13441 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13442 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13443 // overflow occurs. This limit also implies that a signed comparison
13444 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13445 // the high bits on both sides must be zero.
13446 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13447 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13448 Limit = Limit.zext(OuterBitWidth);
13449 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13450 };
13451 auto Flags = AR->getNoWrapFlags();
13452 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13453 Flags = setFlags(Flags, SCEV::FlagNUW);
13454
13455 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13456 if (AR->hasNoUnsignedWrap()) {
13457 // Emulate what getZeroExtendExpr would have done during construction
13458 // if we'd been able to infer the fact just above at that time.
13459 const SCEV *Step = AR->getStepRecurrence(*this);
13460 Type *Ty = ZExt->getType();
13461 auto *S = getAddRecExpr(
13463 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13465 }
13466 }
13467 }
13468 }
13469
13470
13471 if (!IV && AllowPredicates) {
13472 // Try to make this an AddRec using runtime tests, in the first X
13473 // iterations of this loop, where X is the SCEV expression found by the
13474 // algorithm below.
13475 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13476 PredicatedIV = true;
13477 }
13478
13479 // Avoid weird loops
13480 if (!IV || IV->getLoop() != L || !IV->isAffine())
13481 return getCouldNotCompute();
13482
13483 // A precondition of this method is that the condition being analyzed
13484 // reaches an exiting branch which dominates the latch. Given that, we can
13485 // assume that an increment which violates the nowrap specification and
13486 // produces poison must cause undefined behavior when the resulting poison
13487 // value is branched upon and thus we can conclude that the backedge is
13488 // taken no more often than would be required to produce that poison value.
13489 // Note that a well defined loop can exit on the iteration which violates
13490 // the nowrap specification if there is another exit (either explicit or
13491 // implicit/exceptional) which causes the loop to execute before the
13492 // exiting instruction we're analyzing would trigger UB.
13493 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13494 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13496
13497 const SCEV *Stride = IV->getStepRecurrence(*this);
13498
13499 bool PositiveStride = isKnownPositive(Stride);
13500
13501 // Avoid negative or zero stride values.
13502 if (!PositiveStride) {
13503 // We can compute the correct backedge taken count for loops with unknown
13504 // strides if we can prove that the loop is not an infinite loop with side
13505 // effects. Here's the loop structure we are trying to handle -
13506 //
13507 // i = start
13508 // do {
13509 // A[i] = i;
13510 // i += s;
13511 // } while (i < end);
13512 //
13513 // The backedge taken count for such loops is evaluated as -
13514 // (max(end, start + stride) - start - 1) /u stride
13515 //
13516 // The additional preconditions that we need to check to prove correctness
13517 // of the above formula is as follows -
13518 //
13519 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13520 // NoWrap flag).
13521 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13522 // no side effects within the loop)
13523 // c) loop has a single static exit (with no abnormal exits)
13524 //
13525 // Precondition a) implies that if the stride is negative, this is a single
13526 // trip loop. The backedge taken count formula reduces to zero in this case.
13527 //
13528 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13529 // then a zero stride means the backedge can't be taken without executing
13530 // undefined behavior.
13531 //
13532 // The positive stride case is the same as isKnownPositive(Stride) returning
13533 // true (original behavior of the function).
13534 //
13535 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13537 return getCouldNotCompute();
13538
13539 if (!isKnownNonZero(Stride)) {
13540 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13541 // if it might eventually be greater than start and if so, on which
13542 // iteration. We can't even produce a useful upper bound.
13543 if (!isLoopInvariant(RHS, L))
13544 return getCouldNotCompute();
13545
13546 // We allow a potentially zero stride, but we need to divide by stride
13547 // below. Since the loop can't be infinite and this check must control
13548 // the sole exit, we can infer the exit must be taken on the first
13549 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13550 // we know the numerator in the divides below must be zero, so we can
13551 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13552 // and produce the right result.
13553 // FIXME: Handle the case where Stride is poison?
13554 auto wouldZeroStrideBeUB = [&]() {
13555 // Proof by contradiction. Suppose the stride were zero. If we can
13556 // prove that the backedge *is* taken on the first iteration, then since
13557 // we know this condition controls the sole exit, we must have an
13558 // infinite loop. We can't have a (well defined) infinite loop per
13559 // check just above.
13560 // Note: The (Start - Stride) term is used to get the start' term from
13561 // (start' + stride,+,stride). Remember that we only care about the
13562 // result of this expression when stride == 0 at runtime.
13563 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13564 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13565 };
13566 if (!wouldZeroStrideBeUB()) {
13567 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13568 }
13569 }
13570 } else if (!NoWrap) {
13571 // Avoid proven overflow cases: this will ensure that the backedge taken
13572 // count will not generate any unsigned overflow.
13573 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13574 return getCouldNotCompute();
13575 }
13576
13577 // On all paths just preceeding, we established the following invariant:
13578 // IV can be assumed not to overflow up to and including the exiting
13579 // iteration. We proved this in one of two ways:
13580 // 1) We can show overflow doesn't occur before the exiting iteration
13581 // 1a) canIVOverflowOnLT, and b) step of one
13582 // 2) We can show that if overflow occurs, the loop must execute UB
13583 // before any possible exit.
13584 // Note that we have not yet proved RHS invariant (in general).
13585
13586 const SCEV *Start = IV->getStart();
13587
13588 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13589 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13590 // Use integer-typed versions for actual computation; we can't subtract
13591 // pointers in general.
13592 const SCEV *OrigStart = Start;
13593 const SCEV *OrigRHS = RHS;
13594 if (Start->getType()->isPointerTy()) {
13596 if (isa<SCEVCouldNotCompute>(Start))
13597 return Start;
13598 }
13599 if (RHS->getType()->isPointerTy()) {
13602 return RHS;
13603 }
13604
13605 const SCEV *End = nullptr, *BECount = nullptr,
13606 *BECountIfBackedgeTaken = nullptr;
13607 if (!isLoopInvariant(RHS, L)) {
13608 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13609 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13610 any(RHSAddRec->getNoWrapFlags())) {
13611 // The structure of loop we are trying to calculate backedge count of:
13612 //
13613 // left = left_start
13614 // right = right_start
13615 //
13616 // while(left < right){
13617 // ... do something here ...
13618 // left += s1; // stride of left is s1 (s1 > 0)
13619 // right += s2; // stride of right is s2 (s2 < 0)
13620 // }
13621 //
13622
13623 const SCEV *RHSStart = RHSAddRec->getStart();
13624 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13625
13626 // If Stride - RHSStride is positive and does not overflow, we can write
13627 // backedge count as ->
13628 // ceil((End - Start) /u (Stride - RHSStride))
13629 // Where, End = max(RHSStart, Start)
13630
13631 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13632 if (isKnownNegative(RHSStride) &&
13633 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13634 RHSStride)) {
13635
13636 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13637 if (isKnownPositive(Denominator)) {
13638 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13639 : getUMaxExpr(RHSStart, Start);
13640
13641 // We can do this because End >= Start, as End = max(RHSStart, Start)
13642 const SCEV *Delta = getMinusSCEV(End, Start);
13643
13644 BECount = getUDivCeilSCEV(Delta, Denominator);
13645 BECountIfBackedgeTaken =
13646 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13647 }
13648 }
13649 }
13650 if (BECount == nullptr) {
13651 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13652 // given the start, stride and max value for the end bound of the
13653 // loop (RHS), and the fact that IV does not overflow (which is
13654 // checked above).
13655 const SCEV *MaxBECount = computeMaxBECountForLT(
13656 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13657 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13658 MaxBECount, false /*MaxOrZero*/, Predicates);
13659 }
13660 } else {
13661 // We use the expression (max(End,Start)-Start)/Stride to describe the
13662 // backedge count, as if the backedge is taken at least once
13663 // max(End,Start) is End and so the result is as above, and if not
13664 // max(End,Start) is Start so we get a backedge count of zero.
13665 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13666 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13667 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13668 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13669 // Can we prove (max(RHS,Start) > Start - Stride?
13670 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13671 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13672 // In this case, we can use a refined formula for computing backedge
13673 // taken count. The general formula remains:
13674 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13675 // We want to use the alternate formula:
13676 // "((End - 1) - (Start - Stride)) /u Stride"
13677 // Let's do a quick case analysis to show these are equivalent under
13678 // our precondition that max(RHS,Start) > Start - Stride.
13679 // * For RHS <= Start, the backedge-taken count must be zero.
13680 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13681 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13682 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13683 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13684 // reducing this to the stride of 1 case.
13685 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13686 // Stride".
13687 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13688 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13689 // "((RHS - (Start - Stride) - 1) /u Stride".
13690 // Our preconditions trivially imply no overflow in that form.
13691 const SCEV *MinusOne = getMinusOne(Stride->getType());
13692 const SCEV *Numerator =
13693 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13694 BECount = getUDivExpr(Numerator, Stride);
13695 }
13696
13697 if (!BECount) {
13698 auto canProveRHSGreaterThanEqualStart = [&]() {
13699 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13700 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13701 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13702
13703 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13704 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13705 return true;
13706
13707 // (RHS > Start - 1) implies RHS >= Start.
13708 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13709 // "Start - 1" doesn't overflow.
13710 // * For signed comparison, if Start - 1 does overflow, it's equal
13711 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13712 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13713 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13714 //
13715 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13716 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13717 auto *StartMinusOne =
13718 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13719 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13720 };
13721
13722 // If we know that RHS >= Start in the context of loop, then we know
13723 // that max(RHS, Start) = RHS at this point.
13724 if (canProveRHSGreaterThanEqualStart()) {
13725 End = RHS;
13726 } else {
13727 // If RHS < Start, the backedge will be taken zero times. So in
13728 // general, we can write the backedge-taken count as:
13729 //
13730 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13731 //
13732 // We convert it to the following to make it more convenient for SCEV:
13733 //
13734 // ceil(max(RHS, Start) - Start) / Stride
13735 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13736
13737 // See what would happen if we assume the backedge is taken. This is
13738 // used to compute MaxBECount.
13739 BECountIfBackedgeTaken =
13740 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13741 }
13742
13743 // At this point, we know:
13744 //
13745 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13746 // 2. The index variable doesn't overflow.
13747 //
13748 // Therefore, we know N exists such that
13749 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13750 // doesn't overflow.
13751 //
13752 // Using this information, try to prove whether the addition in
13753 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13754 const SCEV *One = getOne(Stride->getType());
13755 bool MayAddOverflow = [&] {
13756 if (isKnownToBeAPowerOfTwo(Stride)) {
13757 // Suppose Stride is a power of two, and Start/End are unsigned
13758 // integers. Let UMAX be the largest representable unsigned
13759 // integer.
13760 //
13761 // By the preconditions of this function, we know
13762 // "(Start + Stride * N) >= End", and this doesn't overflow.
13763 // As a formula:
13764 //
13765 // End <= (Start + Stride * N) <= UMAX
13766 //
13767 // Subtracting Start from all the terms:
13768 //
13769 // End - Start <= Stride * N <= UMAX - Start
13770 //
13771 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13772 //
13773 // End - Start <= Stride * N <= UMAX
13774 //
13775 // Stride * N is a multiple of Stride. Therefore,
13776 //
13777 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13778 //
13779 // Since Stride is a power of two, UMAX + 1 is divisible by
13780 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13781 // write:
13782 //
13783 // End - Start <= Stride * N <= UMAX - Stride - 1
13784 //
13785 // Dropping the middle term:
13786 //
13787 // End - Start <= UMAX - Stride - 1
13788 //
13789 // Adding Stride - 1 to both sides:
13790 //
13791 // (End - Start) + (Stride - 1) <= UMAX
13792 //
13793 // In other words, the addition doesn't have unsigned overflow.
13794 //
13795 // A similar proof works if we treat Start/End as signed values.
13796 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13797 // to use signed max instead of unsigned max. Note that we're
13798 // trying to prove a lack of unsigned overflow in either case.
13799 return false;
13800 }
13801 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13802 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13803 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13804 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13805 // 1 <s End.
13806 //
13807 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13808 // End.
13809 return false;
13810 }
13811 return true;
13812 }();
13813
13814 const SCEV *Delta = getMinusSCEV(End, Start);
13815 if (!MayAddOverflow) {
13816 // floor((D + (S - 1)) / S)
13817 // We prefer this formulation if it's legal because it's fewer
13818 // operations.
13819 BECount =
13820 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13821 } else {
13822 BECount = getUDivCeilSCEV(Delta, Stride);
13823 }
13824 }
13825 }
13826
13827 const SCEV *ConstantMaxBECount;
13828 bool MaxOrZero = false;
13829 if (isa<SCEVConstant>(BECount)) {
13830 ConstantMaxBECount = BECount;
13831 } else if (BECountIfBackedgeTaken &&
13832 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13833 // If we know exactly how many times the backedge will be taken if it's
13834 // taken at least once, then the backedge count will either be that or
13835 // zero.
13836 ConstantMaxBECount = BECountIfBackedgeTaken;
13837 MaxOrZero = true;
13838 } else {
13839 ConstantMaxBECount = computeMaxBECountForLT(
13840 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13841 }
13842
13843 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13844 !isa<SCEVCouldNotCompute>(BECount))
13845 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13846
13847 const SCEV *SymbolicMaxBECount =
13848 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13849 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13850 Predicates);
13851}
13852
13853ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13854 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13855 bool ControlsOnlyExit, bool AllowPredicates) {
13857 // We handle only IV > Invariant
13858 if (!isLoopInvariant(RHS, L))
13859 return getCouldNotCompute();
13860
13861 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13862 if (!IV && AllowPredicates)
13863 // Try to make this an AddRec using runtime tests, in the first X
13864 // iterations of this loop, where X is the SCEV expression found by the
13865 // algorithm below.
13866 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13867
13868 // Avoid weird loops
13869 if (!IV || IV->getLoop() != L || !IV->isAffine())
13870 return getCouldNotCompute();
13871
13872 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13873 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13875
13876 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13877
13878 // Avoid negative or zero stride values
13879 if (!isKnownPositive(Stride))
13880 return getCouldNotCompute();
13881
13882 // Avoid proven overflow cases: this will ensure that the backedge taken count
13883 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13884 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13885 // behaviors like the case of C language.
13886 if (!Stride->isOne() && !NoWrap)
13887 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13888 return getCouldNotCompute();
13889
13890 const SCEV *Start = IV->getStart();
13891 const SCEV *End = RHS;
13892 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13893 // If we know that Start >= RHS in the context of loop, then we know that
13894 // min(RHS, Start) = RHS at this point.
13896 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13897 End = RHS;
13898 else
13899 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13900 }
13901
13902 if (Start->getType()->isPointerTy()) {
13904 if (isa<SCEVCouldNotCompute>(Start))
13905 return Start;
13906 }
13907 if (End->getType()->isPointerTy()) {
13908 End = getLosslessPtrToIntExpr(End);
13909 if (isa<SCEVCouldNotCompute>(End))
13910 return End;
13911 }
13912
13913 // Compute ((Start - End) + (Stride - 1)) / Stride.
13914 // FIXME: This can overflow. Holding off on fixing this for now;
13915 // howManyGreaterThans will hopefully be gone soon.
13916 const SCEV *One = getOne(Stride->getType());
13917 const SCEV *BECount = getUDivExpr(
13918 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13919
13920 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13922
13923 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13924 : getUnsignedRangeMin(Stride);
13925
13926 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13927 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13928 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13929
13930 // Although End can be a MIN expression we estimate MinEnd considering only
13931 // the case End = RHS. This is safe because in the other case (Start - End)
13932 // is zero, leading to a zero maximum backedge taken count.
13933 APInt MinEnd =
13934 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13935 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13936
13937 const SCEV *ConstantMaxBECount =
13938 isa<SCEVConstant>(BECount)
13939 ? BECount
13940 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13941 getConstant(MinStride));
13942
13943 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13944 ConstantMaxBECount = BECount;
13945 const SCEV *SymbolicMaxBECount =
13946 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13947
13948 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13949 Predicates);
13950}
13951
13953 ScalarEvolution &SE) const {
13954 if (Range.isFullSet()) // Infinite loop.
13955 return SE.getCouldNotCompute();
13956
13957 // If the start is a non-zero constant, shift the range to simplify things.
13958 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13959 if (!SC->getValue()->isZero()) {
13961 Operands[0] = SE.getZero(SC->getType());
13962 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13964 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13965 return ShiftedAddRec->getNumIterationsInRange(
13966 Range.subtract(SC->getAPInt()), SE);
13967 // This is strange and shouldn't happen.
13968 return SE.getCouldNotCompute();
13969 }
13970
13971 // The only time we can solve this is when we have all constant indices.
13972 // Otherwise, we cannot determine the overflow conditions.
13973 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13974 return SE.getCouldNotCompute();
13975
13976 // Okay at this point we know that all elements of the chrec are constants and
13977 // that the start element is zero.
13978
13979 // First check to see if the range contains zero. If not, the first
13980 // iteration exits.
13981 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13982 if (!Range.contains(APInt(BitWidth, 0)))
13983 return SE.getZero(getType());
13984
13985 if (isAffine()) {
13986 // If this is an affine expression then we have this situation:
13987 // Solve {0,+,A} in Range === Ax in Range
13988
13989 // We know that zero is in the range. If A is positive then we know that
13990 // the upper value of the range must be the first possible exit value.
13991 // If A is negative then the lower of the range is the last possible loop
13992 // value. Also note that we already checked for a full range.
13993 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13994 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13995
13996 // The exit value should be (End+A)/A.
13997 APInt ExitVal = (End + A).udiv(A);
13998 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13999
14000 // Evaluate at the exit value. If we really did fall out of the valid
14001 // range, then we computed our trip count, otherwise wrap around or other
14002 // things must have happened.
14003 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
14004 if (Range.contains(Val->getValue()))
14005 return SE.getCouldNotCompute(); // Something strange happened
14006
14007 // Ensure that the previous value is in the range.
14008 assert(Range.contains(
14010 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
14011 "Linear scev computation is off in a bad way!");
14012 return SE.getConstant(ExitValue);
14013 }
14014
14015 if (isQuadratic()) {
14016 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
14017 return SE.getConstant(*S);
14018 }
14019
14020 return SE.getCouldNotCompute();
14021}
14022
14023const SCEVAddRecExpr *
14025 assert(getNumOperands() > 1 && "AddRec with zero step?");
14026 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
14027 // but in this case we cannot guarantee that the value returned will be an
14028 // AddRec because SCEV does not have a fixed point where it stops
14029 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
14030 // may happen if we reach arithmetic depth limit while simplifying. So we
14031 // construct the returned value explicitly.
14033 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
14034 // (this + Step) is {A+B,+,B+C,+...,+,N}.
14035 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
14036 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
14037 // We know that the last operand is not a constant zero (otherwise it would
14038 // have been popped out earlier). This guarantees us that if the result has
14039 // the same last operand, then it will also not be popped out, meaning that
14040 // the returned value will be an AddRec.
14041 const SCEV *Last = getOperand(getNumOperands() - 1);
14042 assert(!Last->isZero() && "Recurrency with zero step?");
14043 Ops.push_back(Last);
14046}
14047
14048// Return true when S contains at least one undef or poison value.
14050 return SCEVExprContains(
14051 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
14052}
14053
14054// Return true when S contains a SCEVUnknown whose underlying Value is null.
14056 return SCEVExprContains(S, [](const SCEV *S) {
14057 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
14058 return SU->getValue() == nullptr;
14059 return false;
14060 });
14061}
14062
14063/// Return the size of an element read or written by Inst, or nullptr if Inst is neither a load nor a store.
14065 Type *Ty;
14066 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
14067 Ty = Store->getValueOperand()->getType();
14068 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
14069 Ty = Load->getType();
14070 else
14071 return nullptr;
14072
14074 return getSizeOfExpr(ETy, Ty);
14075}
14076
14077//===----------------------------------------------------------------------===//
14078// SCEVCallbackVH Class Implementation
14079//===----------------------------------------------------------------------===//
14080
14082 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
14083 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
14084 SE->ConstantEvolutionLoopExitValue.erase(PN);
14085 SE->eraseValueFromMap(getValPtr());
14086 // this now dangles!
14087}
14088
14089void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
14090 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
14091
14092 // Forget all the expressions associated with users of the old value,
14093 // so that future queries will recompute the expressions using the new
14094 // value.
14095 SE->forgetValue(getValPtr());
14096 // this now dangles!
14097}
14098
14099ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
14100 : CallbackVH(V), SE(se) {}
14101
14102//===----------------------------------------------------------------------===//
14103// ScalarEvolution Class Implementation
14104//===----------------------------------------------------------------------===//
14105
14108 LoopInfo &LI)
14109 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
14110 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
14111 LoopDispositions(64), BlockDispositions(64) {
14112 // To use guards for proving predicates, we need to scan every instruction in
14113 // relevant basic blocks, and not just terminators. Doing this is a waste of
14114 // time if the IR does not actually contain any calls to
14115 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
14116 //
14117 // This pessimizes the case where a pass that preserves ScalarEvolution wants
14118 // to _add_ guards to the module when there weren't any before, and wants
14119 // ScalarEvolution to optimize based on those guards. For now we prefer to be
14120 // efficient in lieu of being smart in that rather obscure case.
14121
14122 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
14123 F.getParent(), Intrinsic::experimental_guard);
14124 HasGuards = GuardDecl && !GuardDecl->use_empty();
14125}
14126
14128 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
14129 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
14130 ValueExprMap(std::move(Arg.ValueExprMap)),
14131 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
14132 PendingMerges(std::move(Arg.PendingMerges)),
14133 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14134 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14135 PredicatedBackedgeTakenCounts(
14136 std::move(Arg.PredicatedBackedgeTakenCounts)),
14137 BECountUsers(std::move(Arg.BECountUsers)),
14138 ConstantEvolutionLoopExitValue(
14139 std::move(Arg.ConstantEvolutionLoopExitValue)),
14140 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14141 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14142 LoopDispositions(std::move(Arg.LoopDispositions)),
14143 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14144 BlockDispositions(std::move(Arg.BlockDispositions)),
14145 SCEVUsers(std::move(Arg.SCEVUsers)),
14146 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14147 SignedRanges(std::move(Arg.SignedRanges)),
14148 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14149 UniquePreds(std::move(Arg.UniquePreds)),
14150 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14151 LoopUsers(std::move(Arg.LoopUsers)),
14152 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14153 FirstUnknown(Arg.FirstUnknown) {
14154 Arg.FirstUnknown = nullptr;
14155}
14156
14158 // Iterate through all the SCEVUnknown instances and call their
14159 // destructors, so that they release their references to their values.
14160 for (SCEVUnknown *U = FirstUnknown; U;) {
14161 SCEVUnknown *Tmp = U;
14162 U = U->Next;
14163 Tmp->~SCEVUnknown();
14164 }
14165 FirstUnknown = nullptr;
14166
14167 ExprValueMap.clear();
14168 ValueExprMap.clear();
14169 HasRecMap.clear();
14170 BackedgeTakenCounts.clear();
14171 PredicatedBackedgeTakenCounts.clear();
14172
14173 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14174 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14175 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14176 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14177}
14178
14182
14183/// When printing a top-level SCEV for trip counts, it's helpful to include
14184/// a type for constants which are otherwise hard to disambiguate.
14185static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14186 if (isa<SCEVConstant>(S))
14187 OS << *S->getType() << " ";
14188 OS << *S;
14189}
14190
14192 const Loop *L) {
14193 // Print all inner loops first
14194 for (Loop *I : *L)
14195 PrintLoopInfo(OS, SE, I);
14196
14197 OS << "Loop ";
14198 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14199 OS << ": ";
14200
14201 SmallVector<BasicBlock *, 8> ExitingBlocks;
14202 L->getExitingBlocks(ExitingBlocks);
14203 if (ExitingBlocks.size() != 1)
14204 OS << "<multiple exits> ";
14205
14206 auto *BTC = SE->getBackedgeTakenCount(L);
14207 if (!isa<SCEVCouldNotCompute>(BTC)) {
14208 OS << "backedge-taken count is ";
14209 PrintSCEVWithTypeHint(OS, BTC);
14210 } else
14211 OS << "Unpredictable backedge-taken count.";
14212 OS << "\n";
14213
14214 if (ExitingBlocks.size() > 1)
14215 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14216 OS << " exit count for " << ExitingBlock->getName() << ": ";
14217 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14218 PrintSCEVWithTypeHint(OS, EC);
14219 if (isa<SCEVCouldNotCompute>(EC)) {
14220 // Retry with predicates.
14222 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14223 if (!isa<SCEVCouldNotCompute>(EC)) {
14224 OS << "\n predicated exit count for " << ExitingBlock->getName()
14225 << ": ";
14226 PrintSCEVWithTypeHint(OS, EC);
14227 OS << "\n Predicates:\n";
14228 for (const auto *P : Predicates)
14229 P->print(OS, 4);
14230 }
14231 }
14232 OS << "\n";
14233 }
14234
14235 OS << "Loop ";
14236 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14237 OS << ": ";
14238
14239 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14240 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14241 OS << "constant max backedge-taken count is ";
14242 PrintSCEVWithTypeHint(OS, ConstantBTC);
14244 OS << ", actual taken count either this or zero.";
14245 } else {
14246 OS << "Unpredictable constant max backedge-taken count. ";
14247 }
14248
14249 OS << "\n"
14250 "Loop ";
14251 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14252 OS << ": ";
14253
14254 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14255 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14256 OS << "symbolic max backedge-taken count is ";
14257 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14259 OS << ", actual taken count either this or zero.";
14260 } else {
14261 OS << "Unpredictable symbolic max backedge-taken count. ";
14262 }
14263 OS << "\n";
14264
14265 if (ExitingBlocks.size() > 1)
14266 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14267 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14268 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14270 PrintSCEVWithTypeHint(OS, ExitBTC);
14271 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14272 // Retry with predicates.
14274 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14276 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14277 OS << "\n predicated symbolic max exit count for "
14278 << ExitingBlock->getName() << ": ";
14279 PrintSCEVWithTypeHint(OS, ExitBTC);
14280 OS << "\n Predicates:\n";
14281 for (const auto *P : Predicates)
14282 P->print(OS, 4);
14283 }
14284 }
14285 OS << "\n";
14286 }
14287
14289 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14290 if (PBT != BTC) {
14291 OS << "Loop ";
14292 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14293 OS << ": ";
14294 if (!isa<SCEVCouldNotCompute>(PBT)) {
14295 OS << "Predicated backedge-taken count is ";
14296 PrintSCEVWithTypeHint(OS, PBT);
14297 } else
14298 OS << "Unpredictable predicated backedge-taken count.";
14299 OS << "\n";
14300 OS << " Predicates:\n";
14301 for (const auto *P : Preds)
14302 P->print(OS, 4);
14303 }
14304 Preds.clear();
14305
14306 auto *PredConstantMax =
14308 if (PredConstantMax != ConstantBTC) {
14309 OS << "Loop ";
14310 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14311 OS << ": ";
14312 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14313 OS << "Predicated constant max backedge-taken count is ";
14314 PrintSCEVWithTypeHint(OS, PredConstantMax);
14315 } else
14316 OS << "Unpredictable predicated constant max backedge-taken count.";
14317 OS << "\n";
14318 OS << " Predicates:\n";
14319 for (const auto *P : Preds)
14320 P->print(OS, 4);
14321 }
14322 Preds.clear();
14323
14324 auto *PredSymbolicMax =
14326 if (SymbolicBTC != PredSymbolicMax) {
14327 OS << "Loop ";
14328 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14329 OS << ": ";
14330 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14331 OS << "Predicated symbolic max backedge-taken count is ";
14332 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14333 } else
14334 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14335 OS << "\n";
14336 OS << " Predicates:\n";
14337 for (const auto *P : Preds)
14338 P->print(OS, 4);
14339 }
14340
14342 OS << "Loop ";
14343 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14344 OS << ": ";
14345 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14346 }
14347}
14348
14349namespace llvm {
14350// Note: these overloaded operators need to be in the llvm namespace for them
14351// to be resolved correctly. If we put them outside the llvm namespace, the
14352//
14353// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14354//
14355// code below "breaks" and starts printing raw enum values as opposed to the
14356// string values.
14359 switch (LD) {
14361 OS << "Variant";
14362 break;
14364 OS << "Invariant";
14365 break;
14367 OS << "Uniform";
14368 break;
14370 OS << "Computable";
14371 break;
14372 }
14373 return OS;
14374}
14375
14378 switch (BD) {
14380 OS << "DoesNotDominate";
14381 break;
14383 OS << "Dominates";
14384 break;
14386 OS << "ProperlyDominates";
14387 break;
14388 }
14389 return OS;
14390}
14391} // namespace llvm
14392
14394 // ScalarEvolution's implementation of the print method is to print
14395 // out SCEV values of all instructions that are interesting. Doing
14396 // this potentially causes it to create new SCEV objects though,
14397 // which technically conflicts with the const qualifier. This isn't
14398 // observable from outside the class though, so casting away the
14399 // const isn't dangerous.
14400 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14401
14402 if (ClassifyExpressions) {
14403 OS << "Classifying expressions for: ";
14404 F.printAsOperand(OS, /*PrintType=*/false);
14405 OS << "\n";
14406 for (Instruction &I : instructions(F))
14407 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14408 OS << I << '\n';
14409 OS << " --> ";
14410 const SCEV *SV = SE.getSCEV(&I);
14411 SV->print(OS);
14412 if (!isa<SCEVCouldNotCompute>(SV)) {
14413 OS << " U: ";
14414 SE.getUnsignedRange(SV).print(OS);
14415 OS << " S: ";
14416 SE.getSignedRange(SV).print(OS);
14417 }
14418
14419 const Loop *L = LI.getLoopFor(I.getParent());
14420
14421 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14422 if (AtUse != SV) {
14423 OS << " --> ";
14424 AtUse->print(OS);
14425 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14426 OS << " U: ";
14427 SE.getUnsignedRange(AtUse).print(OS);
14428 OS << " S: ";
14429 SE.getSignedRange(AtUse).print(OS);
14430 }
14431 }
14432
14433 if (L) {
14434 OS << "\t\t" "Exits: ";
14435 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14436 if (!SE.isLoopInvariant(ExitValue, L)) {
14437 OS << "<<Unknown>>";
14438 } else {
14439 OS << *ExitValue;
14440 }
14441
14442 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14443 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14444 OS << LS;
14445 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14446 OS << ": " << SE.getLoopDisposition(SV, Iter);
14447 }
14448
14449 for (const auto *InnerL : depth_first(L)) {
14450 if (InnerL == L)
14451 continue;
14452 OS << LS;
14453 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14454 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14455 }
14456
14457 OS << " }";
14458 }
14459
14460 OS << "\n";
14461 }
14462 }
14463
14464 OS << "Determining loop execution counts for: ";
14465 F.printAsOperand(OS, /*PrintType=*/false);
14466 OS << "\n";
14467 for (Loop *I : LI)
14468 PrintLoopInfo(OS, &SE, I);
14469}
14470
14473 auto &Values = LoopDispositions[S];
14474 for (auto &V : Values) {
14475 if (V.getPointer() == L)
14476 return V.getInt();
14477 }
14478 Values.emplace_back(L, LoopVariant);
14479 LoopDisposition D = computeLoopDisposition(S, L);
14480 auto &Values2 = LoopDispositions[S];
14481 for (auto &V : llvm::reverse(Values2)) {
14482 if (V.getPointer() == L) {
14483 V.setInt(D);
14484 break;
14485 }
14486 }
14487 return D;
14488}
14489
14491ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14492 switch (S->getSCEVType()) {
14493 case scConstant:
14494 case scVScale:
14495 return LoopInvariant;
14496 case scAddRecExpr: {
14497 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14498
14499 // If L is the addrec's loop, it's computable.
14500 if (AR->getLoop() == L)
14501 return LoopComputable;
14502
14503 // Add recurrences are never invariant in the function-body (null loop).
14504 if (!L)
14505 return LoopVariant;
14506
14507 // Everything that is not defined at loop entry is variant.
14508 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader())) {
14509 if (L->contains(AR->getLoop()) &&
14510 llvm::all_of(AR->operands(),
14511 [&](const SCEV *Op) { return isLoopUniform(Op, L); }))
14512 return LoopUniform;
14513
14514 return LoopVariant;
14515 }
14516 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14517 " dominate the contained loop's header?");
14518
14519 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14520 if (AR->getLoop()->contains(L))
14521 return LoopInvariant;
14522
14523 // This recurrence is variant w.r.t. L if any of its operands
14524 // are variant.
14525 for (SCEVUse Op : AR->operands())
14526 if (!isLoopInvariant(Op, L))
14527 return LoopVariant;
14528
14529 // Otherwise it's loop-invariant.
14530 return LoopInvariant;
14531 }
14532 case scTruncate:
14533 case scZeroExtend:
14534 case scSignExtend:
14535 case scPtrToAddr:
14536 case scPtrToInt:
14537 case scAddExpr:
14538 case scMulExpr:
14539 case scUDivExpr:
14540 case scUMaxExpr:
14541 case scSMaxExpr:
14542 case scUMinExpr:
14543 case scSMinExpr:
14544 case scSequentialUMinExpr: {
14545 bool HasVarying = false;
14546 bool HasUniform = false;
14547 for (SCEVUse Op : S->operands()) {
14549 if (D == LoopVariant)
14550 return LoopVariant;
14551 if (D == LoopComputable)
14552 HasVarying = true;
14553 if (D == LoopUniform)
14554 HasUniform = true;
14555 }
14556 return HasVarying ? (HasUniform ? LoopVariant : LoopComputable)
14557 : (HasUniform ? LoopUniform : LoopInvariant);
14558 }
14559 case scUnknown:
14560 // All non-instruction values are loop invariant. All instructions are loop
14561 // invariant if they are not contained in the specified loop.
14562 // Instructions are never considered invariant in the function body
14563 // (null loop) because they are defined within the "loop".
14565 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14566 return LoopInvariant;
14567 case scCouldNotCompute:
14568 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14569 }
14570 llvm_unreachable("Unknown SCEV kind!");
14571}
14572
14573bool ScalarEvolution::isLoopUniform(const SCEV *S, const Loop *L) {
14575 return D == LoopUniform || D == LoopInvariant;
14576}
14577
14579 return getLoopDisposition(S, L) == LoopInvariant;
14580}
14581
14583 return getLoopDisposition(S, L) == LoopComputable;
14584}
14585
14588 auto &Values = BlockDispositions[S];
14589 for (auto &V : Values) {
14590 if (V.getPointer() == BB)
14591 return V.getInt();
14592 }
14593 Values.emplace_back(BB, DoesNotDominateBlock);
14594 BlockDisposition D = computeBlockDisposition(S, BB);
14595 auto &Values2 = BlockDispositions[S];
14596 for (auto &V : llvm::reverse(Values2)) {
14597 if (V.getPointer() == BB) {
14598 V.setInt(D);
14599 break;
14600 }
14601 }
14602 return D;
14603}
14604
14606ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14607 switch (S->getSCEVType()) {
14608 case scConstant:
14609 case scVScale:
14611 case scAddRecExpr: {
14612 // This uses a "dominates" query instead of "properly dominates" query
14613 // to test for proper dominance too, because the instruction which
14614 // produces the addrec's value is a PHI, and a PHI effectively properly
14615 // dominates its entire containing block.
14616 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14617 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14618 return DoesNotDominateBlock;
14619
14620 // Fall through into SCEVNAryExpr handling.
14621 [[fallthrough]];
14622 }
14623 case scTruncate:
14624 case scZeroExtend:
14625 case scSignExtend:
14626 case scPtrToAddr:
14627 case scPtrToInt:
14628 case scAddExpr:
14629 case scMulExpr:
14630 case scUDivExpr:
14631 case scUMaxExpr:
14632 case scSMaxExpr:
14633 case scUMinExpr:
14634 case scSMinExpr:
14635 case scSequentialUMinExpr: {
14636 bool Proper = true;
14637 for (const SCEV *NAryOp : S->operands()) {
14639 if (D == DoesNotDominateBlock)
14640 return DoesNotDominateBlock;
14641 if (D == DominatesBlock)
14642 Proper = false;
14643 }
14644 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14645 }
14646 case scUnknown:
14647 if (Instruction *I =
14649 if (I->getParent() == BB)
14650 return DominatesBlock;
14651 if (DT.properlyDominates(I->getParent(), BB))
14653 return DoesNotDominateBlock;
14654 }
14656 case scCouldNotCompute:
14657 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14658 }
14659 llvm_unreachable("Unknown SCEV kind!");
14660}
14661
14662bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14663 return getBlockDisposition(S, BB) >= DominatesBlock;
14664}
14665
14668}
14669
14670bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14671 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14672}
14673
14674void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14675 bool Predicated) {
14676 auto &BECounts =
14677 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14678 auto It = BECounts.find(L);
14679 if (It != BECounts.end()) {
14680 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14681 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14682 if (!isa<SCEVConstant>(S)) {
14683 auto UserIt = BECountUsers.find(S);
14684 assert(UserIt != BECountUsers.end());
14685 UserIt->second.erase({L, Predicated});
14686 }
14687 }
14688 }
14689 BECounts.erase(It);
14690 }
14691}
14692
14693void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14694 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14695 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14696
14697 while (!Worklist.empty()) {
14698 const SCEV *Curr = Worklist.pop_back_val();
14699 auto Users = SCEVUsers.find(Curr);
14700 if (Users != SCEVUsers.end())
14701 for (const auto *User : Users->second)
14702 if (ToForget.insert(User).second)
14703 Worklist.push_back(User);
14704 }
14705
14706 for (const auto *S : ToForget)
14707 forgetMemoizedResultsImpl(S);
14708
14709 PredicatedSCEVRewrites.remove_if(
14710 [&](const auto &Entry) { return ToForget.count(Entry.first.first); });
14711}
14712
14713void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14714 LoopDispositions.erase(S);
14715 BlockDispositions.erase(S);
14716 UnsignedRanges.erase(S);
14717 SignedRanges.erase(S);
14718 HasRecMap.erase(S);
14719 ConstantMultipleCache.erase(S);
14720
14721 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14722 UnsignedWrapViaInductionTried.erase(AR);
14723 SignedWrapViaInductionTried.erase(AR);
14724 }
14725
14726 auto ExprIt = ExprValueMap.find(S);
14727 if (ExprIt != ExprValueMap.end()) {
14728 for (Value *V : ExprIt->second) {
14729 auto ValueIt = ValueExprMap.find_as(V);
14730 if (ValueIt != ValueExprMap.end())
14731 ValueExprMap.erase(ValueIt);
14732 }
14733 ExprValueMap.erase(ExprIt);
14734 }
14735
14736 auto ScopeIt = ValuesAtScopes.find(S);
14737 if (ScopeIt != ValuesAtScopes.end()) {
14738 for (const auto &Pair : ScopeIt->second)
14739 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14740 llvm::erase(ValuesAtScopesUsers[Pair.second],
14741 std::make_pair(Pair.first, S));
14742 ValuesAtScopes.erase(ScopeIt);
14743 }
14744
14745 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14746 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14747 for (const auto &Pair : ScopeUserIt->second)
14748 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14749 ValuesAtScopesUsers.erase(ScopeUserIt);
14750 }
14751
14752 auto BEUsersIt = BECountUsers.find(S);
14753 if (BEUsersIt != BECountUsers.end()) {
14754 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14755 auto Copy = BEUsersIt->second;
14756 for (const auto &Pair : Copy)
14757 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14758 BECountUsers.erase(BEUsersIt);
14759 }
14760
14761 auto FoldUser = FoldCacheUser.find(S);
14762 if (FoldUser != FoldCacheUser.end())
14763 for (auto &KV : FoldUser->second)
14764 FoldCache.erase(KV);
14765 FoldCacheUser.erase(S);
14766}
14767
14768void
14769ScalarEvolution::getUsedLoops(const SCEV *S,
14770 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14771 struct FindUsedLoops {
14772 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14773 : LoopsUsed(LoopsUsed) {}
14774 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14775 bool follow(const SCEV *S) {
14776 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14777 LoopsUsed.insert(AR->getLoop());
14778 return true;
14779 }
14780
14781 bool isDone() const { return false; }
14782 };
14783
14784 FindUsedLoops F(LoopsUsed);
14785 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14786}
14787
14788void ScalarEvolution::getReachableBlocks(
14791 Worklist.push_back(&F.getEntryBlock());
14792 while (!Worklist.empty()) {
14793 BasicBlock *BB = Worklist.pop_back_val();
14794 if (!Reachable.insert(BB).second)
14795 continue;
14796
14797 Value *Cond;
14798 BasicBlock *TrueBB, *FalseBB;
14799 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14800 m_BasicBlock(FalseBB)))) {
14801 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14802 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14803 continue;
14804 }
14805
14806 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14807 const SCEV *L = getSCEV(Cmp->getOperand(0));
14808 const SCEV *R = getSCEV(Cmp->getOperand(1));
14809 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14810 Worklist.push_back(TrueBB);
14811 continue;
14812 }
14813 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14814 R)) {
14815 Worklist.push_back(FalseBB);
14816 continue;
14817 }
14818 }
14819 }
14820
14821 append_range(Worklist, successors(BB));
14822 }
14823}
14824
14826 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14827 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14828
14829 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14830
14831 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14832 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14833 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14834
14835 const SCEV *visitConstant(const SCEVConstant *Constant) {
14836 return SE.getConstant(Constant->getAPInt());
14837 }
14838
14839 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14840 return SE.getUnknown(Expr->getValue());
14841 }
14842
14843 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14844 return SE.getCouldNotCompute();
14845 }
14846 };
14847
14848 SCEVMapper SCM(SE2);
14849 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14850 SE2.getReachableBlocks(ReachableBlocks, F);
14851
14852 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14853 if (containsUndefs(Old) || containsUndefs(New)) {
14854 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14855 // not propagate undef aggressively). This means we can (and do) fail
14856 // verification in cases where a transform makes a value go from "undef"
14857 // to "undef+1" (say). The transform is fine, since in both cases the
14858 // result is "undef", but SCEV thinks the value increased by 1.
14859 return nullptr;
14860 }
14861
14862 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14863 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14864 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14865 return nullptr;
14866
14867 return Delta;
14868 };
14869
14870 while (!LoopStack.empty()) {
14871 auto *L = LoopStack.pop_back_val();
14872 llvm::append_range(LoopStack, *L);
14873
14874 // Only verify BECounts in reachable loops. For an unreachable loop,
14875 // any BECount is legal.
14876 if (!ReachableBlocks.contains(L->getHeader()))
14877 continue;
14878
14879 // Only verify cached BECounts. Computing new BECounts may change the
14880 // results of subsequent SCEV uses.
14881 auto It = BackedgeTakenCounts.find(L);
14882 if (It == BackedgeTakenCounts.end())
14883 continue;
14884
14885 auto *CurBECount =
14886 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14887 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14888
14889 if (CurBECount == SE2.getCouldNotCompute() ||
14890 NewBECount == SE2.getCouldNotCompute()) {
14891 // NB! This situation is legal, but is very suspicious -- whatever pass
14892 // change the loop to make a trip count go from could not compute to
14893 // computable or vice-versa *should have* invalidated SCEV. However, we
14894 // choose not to assert here (for now) since we don't want false
14895 // positives.
14896 continue;
14897 }
14898
14899 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14900 SE.getTypeSizeInBits(NewBECount->getType()))
14901 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14902 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14903 SE.getTypeSizeInBits(NewBECount->getType()))
14904 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14905
14906 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14907 if (Delta && !Delta->isZero()) {
14908 dbgs() << "Trip Count for " << *L << " Changed!\n";
14909 dbgs() << "Old: " << *CurBECount << "\n";
14910 dbgs() << "New: " << *NewBECount << "\n";
14911 dbgs() << "Delta: " << *Delta << "\n";
14912 std::abort();
14913 }
14914 }
14915
14916 // Collect all valid loops currently in LoopInfo.
14917 SmallPtrSet<Loop *, 32> ValidLoops;
14918 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14919 while (!Worklist.empty()) {
14920 Loop *L = Worklist.pop_back_val();
14921 if (ValidLoops.insert(L).second)
14922 Worklist.append(L->begin(), L->end());
14923 }
14924 for (const auto &KV : ValueExprMap) {
14925#ifndef NDEBUG
14926 // Check for SCEV expressions referencing invalid/deleted loops.
14927 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14928 assert(ValidLoops.contains(AR->getLoop()) &&
14929 "AddRec references invalid loop");
14930 }
14931#endif
14932
14933 // Check that the value is also part of the reverse map.
14934 auto It = ExprValueMap.find(KV.second);
14935 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14936 dbgs() << "Value " << *KV.first
14937 << " is in ValueExprMap but not in ExprValueMap\n";
14938 std::abort();
14939 }
14940
14941 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14942 if (!ReachableBlocks.contains(I->getParent()))
14943 continue;
14944 const SCEV *OldSCEV = SCM.visit(KV.second);
14945 const SCEV *NewSCEV = SE2.getSCEV(I);
14946 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14947 if (Delta && !Delta->isZero()) {
14948 dbgs() << "SCEV for value " << *I << " changed!\n"
14949 << "Old: " << *OldSCEV << "\n"
14950 << "New: " << *NewSCEV << "\n"
14951 << "Delta: " << *Delta << "\n";
14952 std::abort();
14953 }
14954 }
14955 }
14956
14957 for (const auto &KV : ExprValueMap) {
14958 for (Value *V : KV.second) {
14959 const SCEV *S = ValueExprMap.lookup(V);
14960 if (!S) {
14961 dbgs() << "Value " << *V
14962 << " is in ExprValueMap but not in ValueExprMap\n";
14963 std::abort();
14964 }
14965 if (S != KV.first) {
14966 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14967 << *KV.first << "\n";
14968 std::abort();
14969 }
14970 }
14971 }
14972
14973 // Verify integrity of SCEV users.
14974 for (const auto &S : UniqueSCEVs) {
14975 for (SCEVUse Op : S.operands()) {
14976 // We do not store dependencies of constants.
14977 if (isa<SCEVConstant>(Op))
14978 continue;
14979 auto It = SCEVUsers.find(Op);
14980 if (It != SCEVUsers.end() && It->second.count(&S))
14981 continue;
14982 dbgs() << "Use of operand " << *Op << " by user " << S
14983 << " is not being tracked!\n";
14984 std::abort();
14985 }
14986 }
14987
14988 // Verify integrity of ValuesAtScopes users.
14989 for (const auto &ValueAndVec : ValuesAtScopes) {
14990 const SCEV *Value = ValueAndVec.first;
14991 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14992 const Loop *L = LoopAndValueAtScope.first;
14993 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14994 if (!isa<SCEVConstant>(ValueAtScope)) {
14995 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14996 if (It != ValuesAtScopesUsers.end() &&
14997 is_contained(It->second, std::make_pair(L, Value)))
14998 continue;
14999 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
15000 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
15001 std::abort();
15002 }
15003 }
15004 }
15005
15006 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
15007 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
15008 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
15009 const Loop *L = LoopAndValue.first;
15010 const SCEV *Value = LoopAndValue.second;
15012 auto It = ValuesAtScopes.find(Value);
15013 if (It != ValuesAtScopes.end() &&
15014 is_contained(It->second, std::make_pair(L, ValueAtScope)))
15015 continue;
15016 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
15017 << *ValueAtScope << " missing in ValuesAtScopes\n";
15018 std::abort();
15019 }
15020 }
15021
15022 // Verify integrity of BECountUsers.
15023 auto VerifyBECountUsers = [&](bool Predicated) {
15024 auto &BECounts =
15025 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
15026 for (const auto &LoopAndBEInfo : BECounts) {
15027 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
15028 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
15029 if (!isa<SCEVConstant>(S)) {
15030 auto UserIt = BECountUsers.find(S);
15031 if (UserIt != BECountUsers.end() &&
15032 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
15033 continue;
15034 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
15035 << " missing from BECountUsers\n";
15036 std::abort();
15037 }
15038 }
15039 }
15040 }
15041 };
15042 VerifyBECountUsers(/* Predicated */ false);
15043 VerifyBECountUsers(/* Predicated */ true);
15044
15045 // Verify intergity of loop disposition cache.
15046 for (auto &[S, Values] : LoopDispositions) {
15047 for (auto [Loop, CachedDisposition] : Values) {
15048 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
15049 if (CachedDisposition != RecomputedDisposition) {
15050 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
15051 << " is incorrect: cached " << CachedDisposition << ", actual "
15052 << RecomputedDisposition << "\n";
15053 std::abort();
15054 }
15055 }
15056 }
15057
15058 // Verify integrity of the block disposition cache.
15059 for (auto &[S, Values] : BlockDispositions) {
15060 for (auto [BB, CachedDisposition] : Values) {
15061 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
15062 if (CachedDisposition != RecomputedDisposition) {
15063 dbgs() << "Cached disposition of " << *S << " for block %"
15064 << BB->getName() << " is incorrect: cached " << CachedDisposition
15065 << ", actual " << RecomputedDisposition << "\n";
15066 std::abort();
15067 }
15068 }
15069 }
15070
15071 // Verify FoldCache/FoldCacheUser caches.
15072 for (auto [FoldID, Expr] : FoldCache) {
15073 auto I = FoldCacheUser.find(Expr);
15074 if (I == FoldCacheUser.end()) {
15075 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
15076 << "!\n";
15077 std::abort();
15078 }
15079 if (!is_contained(I->second, FoldID)) {
15080 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
15081 std::abort();
15082 }
15083 }
15084 for (auto [Expr, IDs] : FoldCacheUser) {
15085 for (auto &FoldID : IDs) {
15086 const SCEV *S = FoldCache.lookup(FoldID);
15087 if (!S) {
15088 dbgs() << "Missing entry in FoldCache for expression " << *Expr
15089 << "!\n";
15090 std::abort();
15091 }
15092 if (S != Expr) {
15093 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
15094 << " != " << *Expr << "!\n";
15095 std::abort();
15096 }
15097 }
15098 }
15099
15100 // Verify that ConstantMultipleCache computations are correct. We check that
15101 // cached multiples and recomputed multiples are multiples of each other to
15102 // verify correctness. It is possible that a recomputed multiple is different
15103 // from the cached multiple due to strengthened no wrap flags or changes in
15104 // KnownBits computations.
15105 for (auto [S, Multiple] : ConstantMultipleCache) {
15106 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
15107 if ((Multiple != 0 && RecomputedMultiple != 0 &&
15108 Multiple.urem(RecomputedMultiple) != 0 &&
15109 RecomputedMultiple.urem(Multiple) != 0)) {
15110 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
15111 << *S << " : Computed " << RecomputedMultiple
15112 << " but cache contains " << Multiple << "!\n";
15113 std::abort();
15114 }
15115 }
15116}
15117
15119 Function &F, const PreservedAnalyses &PA,
15120 FunctionAnalysisManager::Invalidator &Inv) {
15121 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
15122 // of its dependencies is invalidated.
15123 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
15124 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
15125 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
15126 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
15127 Inv.invalidate<LoopAnalysis>(F, PA);
15128}
15129
// Out-of-line definition of the static AnalysisKey member that serves as
// ScalarEvolutionAnalysis' identity key.
15130AnalysisKey ScalarEvolutionAnalysis::Key;
15131
15134 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
15135 auto &AC = AM.getResult<AssumptionAnalysis>(F);
15136 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
15137 auto &LI = AM.getResult<LoopAnalysis>(F);
15138 return ScalarEvolution(F, TLI, AC, DT, LI);
15139}
15140
15146
15149 // For compatibility with opt's -analyze feature under the legacy pass
15150 // manager, which was not ported to the NPM. This keeps tests using
15151 // update_analyze_test_checks.py working.
15152 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15153 << F.getName() << "':\n";
// Printing does not change the IR, so every analysis is reported preserved.
15155 return PreservedAnalyses::all();
15156}
15157
15159 "Scalar Evolution Analysis", false, true)
15165 "Scalar Evolution Analysis", false, true)
15166
15168
15170
// Build a fresh ScalarEvolution for F from the analyses obtained through
// getAnalysis<>(); reset() replaces any instance SE held before.
15172 SE.reset(new ScalarEvolution(
15174 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15176 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
// Per the legacy pass contract, returning false reports that F was not
// modified.
15177 return false;
15178}
15179
15181
// Delegate printing to the ScalarEvolution instance held in SE.
15183 SE->print(OS);
15184}
15185
15187 if (!VerifySCEV)
15188 return;
15189
15190 SE->verify();
15191}
15192
15200
15202 const SCEV *RHS) {
// Equality is represented as a compare predicate using ICMP_EQ.
15203 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15204}
15205
/// Return the uniqued compare predicate "LHS Pred RHS". Requests with the
/// same (kind, Pred, LHS, RHS) tuple return the same node from UniquePreds.
15206const SCEVPredicate *
15208 const SCEV *LHS, const SCEV *RHS) {
15210 assert(LHS->getType() == RHS->getType() &&
15211 "Type mismatch between LHS and RHS");
15212 // Unique this node based on the arguments.
15213 ID.AddInteger(SCEVPredicate::P_Compare);
15214 ID.AddInteger(Pred);
15215 ID.AddPointer(LHS);
15216 ID.AddPointer(RHS);
// Reuse an existing node for this ID if one was already created.
15217 void *IP = nullptr;
15218 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15219 return S;
// Otherwise allocate a new node in SCEVAllocator and register it at the
// insertion position found above. (`Eq` holds any compare kind, not only
// equality.)
15220 SCEVComparePredicate *Eq = new (SCEVAllocator)
15221 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15222 UniquePreds.InsertNode(Eq, IP);
15223 return Eq;
15224}
15225
15227 const SCEVAddRecExpr *AR,
15230 // Unique this node based on the arguments: the AddRec and the wrap flags
15231 ID.AddInteger(SCEVPredicate::P_Wrap);
15232 ID.AddPointer(AR);
15233 ID.AddInteger(AddedFlags);
// Return the existing node if this (AR, AddedFlags) pair was seen before.
15234 void *IP = nullptr;
15235 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15236 return S;
// Otherwise create the predicate in SCEVAllocator and register it.
15237 auto *OF = new (SCEVAllocator)
15238 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15239 UniquePreds.InsertNode(OF, IP);
15240 return OF;
15241}
15242
15243namespace {
15244
// SCEV rewriter that applies known predicates (ICMP_EQ compare predicates
// and already-assumed wrap predicates) and, when allowed, introduces new wrap
// predicates so that extensions of AddRecs can be folded into AddRecs.
15245class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15246public:
15247
15248 /// Rewrites \p S in the context of a loop L and the SCEV predication
15249 /// infrastructure.
15250 ///
15251 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15252 /// equivalences present in \p Pred.
15253 ///
15254 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15255 /// \p NewPreds such that the result will be an AddRecExpr.
15256 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15258 const SCEVPredicate *Pred) {
15259 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15260 return Rewriter.visit(S);
15261 }
15262
// Replace an unknown that is the LHS of an ICMP_EQ compare predicate (either
// Pred itself or a member of a union Pred) by that predicate's RHS. Otherwise
// try to turn a PHI into an AddRec under additional predicates.
15263 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15264 if (Pred) {
15265 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
// NOTE(review): the loop variable below shadows the member Pred.
15266 for (const auto *Pred : U->getPredicates())
15267 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15268 if (IPred->getLHS() == Expr &&
15269 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15270 return IPred->getRHS();
15271 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15272 if (IPred->getLHS() == Expr &&
15273 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15274 return IPred->getRHS();
15275 }
15276 }
15277 return convertToAddRecWithPreds(Expr);
15278 }
15279
15280 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15281 const SCEV *Operand = visit(Expr->getOperand());
15282 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15283 if (AR && AR->getLoop() == L && AR->isAffine()) {
15284 // This couldn't be folded because the operand didn't have the nuw
15285 // flag. Add the nusw flag as an assumption that we could make.
// On success the extension is pushed into the AddRec: zero-extended start,
// sign-extended step.
15286 const SCEV *Step = AR->getStepRecurrence(SE);
15287 Type *Ty = Expr->getType();
15288 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15289 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15290 SE.getSignExtendExpr(Step, Ty), L,
15291 AR->getNoWrapFlags());
15292 }
15293 return SE.getZeroExtendExpr(Operand, Expr->getType());
15294 }
15295
15296 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15297 const SCEV *Operand = visit(Expr->getOperand());
15298 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15299 if (AR && AR->getLoop() == L && AR->isAffine()) {
15300 // This couldn't be folded because the operand didn't have the nsw
15301 // flag. Add the nssw flag as an assumption that we could make.
// On success both start and step are sign-extended into the new AddRec.
15302 const SCEV *Step = AR->getStepRecurrence(SE);
15303 Type *Ty = Expr->getType();
15304 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15305 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15306 SE.getSignExtendExpr(Step, Ty), L,
15307 AR->getNoWrapFlags());
15308 }
15309 return SE.getSignExtendExpr(Operand, Expr->getType());
15310 }
15311
15312private:
15313 explicit SCEVPredicateRewriter(
15314 const Loop *L, ScalarEvolution &SE,
15315 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15316 const SCEVPredicate *Pred)
15317 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15318
// Accept assumption P. Without a NewPreds list, P is accepted only if the
// existing Pred already implies it. Otherwise P is appended to NewPreds and
// always accepted.
15319 bool addOverflowAssumption(const SCEVPredicate *P) {
15320 if (!NewPreds) {
15321 // Check if we've already made this assumption.
15322 return Pred && Pred->implies(P, SE);
15323 }
15324 NewPreds->push_back(P);
15325 return true;
15326 }
15327
// Convenience overload: build the wrap predicate for AR first.
15328 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15330 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15331 return addOverflowAssumption(A);
15332 }
15333
15334 // If \p Expr represents a PHINode, we try to see if it can be represented
15335 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15336 // to add this predicate as a runtime overflow check, we return the AddRec.
15337 // If \p Expr does not meet these conditions (is not a PHI node, or we
15338 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15339 // return \p Expr.
15340 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15341 if (!isa<PHINode>(Expr->getValue()))
15342 return Expr;
15343 std::optional<
15344 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15345 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15346 if (!PredicatedRewrite)
15347 return Expr;
// NOTE(review): predicates pushed to NewPreds by earlier iterations are not
// removed when a later one triggers the early `return Expr`, so NewPreds may
// keep assumptions that the returned expression does not rely on.
15348 for (const auto *P : PredicatedRewrite->second){
15349 // Wrap predicates from outer loops are not supported.
15350 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15351 if (L != WP->getExpr()->getLoop())
15352 return Expr;
15353 }
15354 if (!addOverflowAssumption(P))
15355 return Expr;
15356 }
15357 return PredicatedRewrite->first;
15358 }
15359
// Where new assumptions are recorded; null means "do not add assumptions".
15360 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
// Predicates already known to hold; may be null.
15361 const SCEVPredicate *Pred;
// Only AddRecs of this loop are rewritten under new wrap assumptions.
15362 const Loop *L;
15363};
15364
15365} // end anonymous namespace
15366
/// Rewrite S in the context of loop L using only the predicates in Preds.
/// Passing a null NewPreds list means no new assumptions are introduced.
15367const SCEV *
15369 const SCEVPredicate &Preds) {
15370 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15371}
15372
15374 const SCEV *S, const Loop *L,
// Rewrite S, collecting any new assumptions in TransformPreds. They are only
// transferred to Preds once the rewrite is known to be usable.
15377 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15378 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15379
15380 if (!AddRec)
15381 return nullptr;
15382
15383 // Check if any of the transformed predicates is known to be false. In that
15384 // case, it doesn't make sense to convert to a predicated AddRec, as the
15385 // versioned loop will never execute.
15386 for (const SCEVPredicate *Pred : TransformPreds) {
15387 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15388 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15389 continue;
15390
15391 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15392 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15393 if (isa<SCEVCouldNotCompute>(ExitCount))
15394 continue;
15395
15396 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15397 if (!Step->isOne())
15398 continue;
15399
// With a unit step the final value is Start + ExitCount. If that is known
// to be signed-less-than Start, the NSSW assumption can never hold.
// NOTE(review): the backedge-taken count is sign-extended here; confirm
// that this is intended for large unsigned counts.
15400 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15401 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15402 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15403 return nullptr;
15404 }
15405
15406 // Since the transformation was successful, we can now transfer the SCEV
15407 // predicates.
15408 Preds.append(TransformPreds.begin(), TransformPreds.end());
15409
15410 return AddRec;
15411}
15412
15413/// SCEV predicates
15417
15419 const ICmpInst::Predicate Pred,
15420 const SCEV *LHS, const SCEV *RHS)
// Stores the comparison "LHS Pred RHS". The asserts require same-typed,
// distinct operands.
15421 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15422 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15423 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15424}
15425
15427 ScalarEvolution &SE) const {
15428 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15429
15430 if (!Op)
15431 return false;
15432
15433 if (Pred != ICmpInst::ICMP_EQ)
15434 return false;
15435
15436 return Op->LHS == LHS && Op->RHS == RHS;
15437}
15438
15439bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15440
15442 if (Pred == ICmpInst::ICMP_EQ)
15443 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15444 else
15445 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15446 << *RHS << "\n";
15447
15448}
15449
15451 const SCEVAddRecExpr *AR,
15452 IncrementWrapFlags Flags)
// Records the AddRec and the increment wrap flags assumed for it.
15453 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15454
15455const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15456
15458 ScalarEvolution &SE) const {
// This can only imply N if N is a wrap predicate that requires no flag this
// predicate lacks.
15459 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15460 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15461 return false;
15462
15463 if (Op->AR == AR)
15464 return true;
15465
15466 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15468 return false;
15469
// Pointer-typed and integer-typed starts are never compared with each other.
15470 const SCEV *Start = AR->getStart();
15471 const SCEV *OpStart = Op->AR->getStart();
15472 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15473 return false;
15474
15475 // Reject pointers to different address spaces.
15476 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15477 return false;
15478
15479 // NUSW/NSSW on a wider-type AddRec does not imply the same on a
15480 // narrower-type AddRec.
15481 if (SE.getTypeSizeInBits(AR->getType()) >
15482 SE.getTypeSizeInBits(Op->AR->getType()))
15483 return false;
15484
15485 const SCEV *Step = AR->getStepRecurrence(SE);
15486 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15487 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15488 return false;
15489
15490 // If both steps are positive, this implies N if N's start and step are
15491 // ULE/SLE (for NUSW/NSSW) this predicate's start and step.
15492 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15493 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15494 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15495
// Starts are widened according to the signedness of the flag being checked.
15496 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15497 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15498 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15499 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15500 : SE.getNoopOrSignExtend(Start, WiderTy);
15502 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15503 SE.isKnownPredicate(Pred, OpStart, Start);
15504}
15505
15507 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15508 IncrementWrapFlags IFlags = Flags;
15509
15510 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15511 IFlags = clearFlags(IFlags, IncrementNSSW);
15512
15513 return IFlags == IncrementAnyWrap;
15514}
15515
/// Prints the AddRec, then tags (<nusw>, <nssw>) for the wrap flags added by
/// this predicate, then a newline.
15516void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15517 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15519 OS << "<nusw>";
15521 OS << "<nssw>";
15522 OS << "\n";
15523}
15524
15527 ScalarEvolution &SE) {
15528 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15529 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15530
15531 // We can safely transfer the NSW flag as NSSW.
15532 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15533 ImpliedFlags = IncrementNSSW;
15534
15535 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15536 // If the increment is positive, the SCEV NUW flag will also imply the
15537 // WrapPredicate NUSW flag.
15538 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15539 if (Step->getValue()->getValue().isNonNegative())
15540 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15541 }
15542
15543 return ImpliedFlags;
15544}
15545
15546/// Union predicates are not uniqued (cached), so they get a dummy, empty node ID.
15548 ScalarEvolution &SE)
15549 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
// Insert members one at a time through add(), which flattens nested unions
// and, for small sets, skips/prunes predicates implied by others.
15550 for (const auto *P : Preds)
15551 add(P, SE);
15552}
15553
15555 return all_of(Preds,
15556 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15557}
15558
15560 ScalarEvolution &SE) const {
15561 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15562 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15563 return this->implies(I, SE);
15564 });
15565
15566 if (any_of(Preds,
15567 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); }))
15568 return true;
15569
15570 // A wrap predicate may be implied by a wrap predicate in Preds after applying
15571 // equal predicates.
15572 const auto *NWrap = dyn_cast<SCEVWrapPredicate>(N);
15573 if (!NWrap)
15574 return false;
15575 const Loop *L = NWrap->getExpr()->getLoop();
15576 return any_of(Preds, [&](const SCEVPredicate *I) {
15577 const auto *IWrap = dyn_cast<SCEVWrapPredicate>(I);
15578 if (!IWrap)
15579 return false;
15580 const auto *RewrittenAR = dyn_cast<SCEVAddRecExpr>(
15581 SE.rewriteUsingPredicate(IWrap->getExpr(), L, *this));
15582 return RewrittenAR &&
15583 SE.getWrapPredicate(RewrittenAR, IWrap->getFlags())->implies(N, SE);
15584 });
15585}
15586
// Print every member predicate at the same indentation depth.
15588 for (const auto *Pred : Preds)
15589 Pred->print(OS, Depth);
15590}
15591
/// Add N to this union. Nested unions are flattened member by member. While
/// the union is small, N is dropped if it is already implied, and existing
/// members implied by N are pruned.
15592void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
// NOTE(review): if N is this union itself, Set->Preds aliases Preds, which the
// recursive add() calls modify while it is being iterated.
15593 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15594 for (const auto *Pred : Set->Preds)
15595 add(Pred, SE);
15596 return;
15597 }
15598
15599 // Implication checks are quadratic in the number of predicates. Stop doing
15600 // them if there are many predicates, as they should be too expensive to use
15601 // anyway at that point.
15602 bool CheckImplies = Preds.size() < 16;
15603
15604 // Only add predicate if it is not already implied by this union predicate.
15605 if (CheckImplies && implies(N, SE))
15606 return;
15607
15608 // Build a new vector containing the current predicates, except the ones that
15609 // are implied by the new predicate N.
15611 for (auto *P : Preds) {
15612 if (CheckImplies && N->implies(P, SE))
15613 continue;
15614 PrunedPreds.push_back(P);
15615 }
15616 Preds = std::move(PrunedPreds);
15617 Preds.push_back(N);
15618}
15619
15621 Loop &L)
15622 : SE(SE), L(L) {
// Start from a union predicate built from Empty (presumably an empty list,
// i.e. no assumptions yet).
15624 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15625}
15626
15629 for (const auto *Op : Ops)
15630 // We do not expect that forgetting cached data for SCEVConstants will ever
15631 // open any prospects for sharpening or introduce any correctness issues,
15632 // so we don't bother storing their dependencies.
15633 if (!isa<SCEVConstant>(Op))
15634 SCEVUsers[Op].insert(User);
15635}
15636
15638 for (const SCEV *Op : Ops)
15639 // We do not expect that forgetting cached data for SCEVConstants will ever
15640 // open any prospects for sharpening or introduce any correctness issues,
15641 // so we don't bother storing their dependencies.
15642 if (!isa<SCEVConstant>(Op))
15643 SCEVUsers[Op].insert(User);
15644}
15645
15647 const SCEV *Expr = SE.getSCEV(V);
15648 return getPredicatedSCEV(Expr);
15649}
15650
15652 RewriteEntry &Entry = RewriteMap[Expr];
15653
15654 // If we already have an entry and the version matches, return it.
15655 if (Entry.second && Generation == Entry.first)
15656 return Entry.second;
15657
15658 // We found an entry but it's stale. Rewrite the stale entry
15659 // according to the current predicate.
15660 if (Entry.second)
15661 Expr = Entry.second;
15662
15663 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15664 Entry = {Generation, NewSCEV};
15665
15666 return NewSCEV;
15667}
15668
// Computed lazily on first use and cached in BackedgeCount. Every predicate
// the computation relied on is merged into this PSE's predicate set.
// NOTE(review): the Preds passed below appears to be a local list declared on
// a line not shown here, shadowing the member of the same name; confirm.
15670 if (!BackedgeCount) {
15672 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15673 for (const auto *P : Preds)
15674 addPredicate(*P);
15675 }
15676 return BackedgeCount;
15677}
15678
// Computed lazily on first use and cached in SymbolicMaxBackedgeCount; the
// predicates it depends on are merged into this PSE's predicate set.
15680 if (!SymbolicMaxBackedgeCount) {
15682 SymbolicMaxBackedgeCount =
15683 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15684 for (const auto *P : Preds)
15685 addPredicate(*P);
15686 }
15687 return SymbolicMaxBackedgeCount;
15688}
15689
// Computed lazily on first use and cached in SmallConstantMaxTripCount (an
// optional-like member); the predicates it depends on are added to this PSE.
15691 if (!SmallConstantMaxTripCount) {
15693 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15694 for (const auto *P : Preds)
15695 addPredicate(*P);
15696 }
15697 return *SmallConstantMaxTripCount;
15698}
15699
15701 if (Preds->implies(&Pred, SE))
15702 return;
15703
15704 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15705 NewPreds.push_back(&Pred);
15706 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15707 updateGeneration();
15708}
15709
// Add each predicate individually; addPredicate() ignores already-implied ones.
15712 for (const SCEVPredicate *P : Preds)
15713 addPredicate(*P);
15714}
15715
// Expose the current union of all predicates added to this PSE.
15717 return *Preds;
15718}
15719
15720void PredicatedScalarEvolution::updateGeneration() {
15721 // If the generation number wrapped recompute everything.
15722 if (++Generation == 0) {
15723 for (auto &II : RewriteMap) {
15724 const SCEV *Rewritten = II.second.second;
15725 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15726 }
15727 }
15728}
15729
// Only AddRec expressions can carry wrap flags; anything else yields false.
15732 const auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V));
15733 if (!AR)
15734 return false;
15735
// Flags already implied by the AddRec's static no-wrap flags
// (getImpliedFlags) are factored in here.
15737 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15738
15740}
15741
// Try to express V's predicated SCEV as an AddRec, possibly under new
// predicates collected in NewPreds.
15744 const SCEV *Expr = this->getSCEV(V);
15746 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15747
15748 if (!New)
15749 return nullptr;
15750
// When the caller supplies ExtraPreds, hand the new predicates to it instead
// of committing them to this PSE (and do not cache the result).
15751 if (ExtraPreds) {
15752 ExtraPreds->append(NewPreds);
15753 return New;
15754 }
15755
15756 addPredicates(NewPreds);
15757
// Cache the AddRec as the predicated rewrite of V's SCEV for this generation.
15758 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15759 return New;
15760}
15761
// Copy construction: the predicate set is rebuilt as a new SCEVUnionPredicate,
// so the copy and the original can gain predicates independently.
// NOTE(review): SymbolicMaxBackedgeCount and SmallConstantMaxTripCount are not
// copied; the copy will recompute them lazily.
15764 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15765 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15766 SE)),
15767 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {}
15768
15770 // For each block.
15771 for (auto *BB : L.getBlocks())
15772 for (auto &I : *BB) {
15773 if (!SE.isSCEVable(I.getType()))
15774 continue;
15775
15776 auto *Expr = SE.getSCEV(&I);
15777 auto II = RewriteMap.find(Expr);
15778
15779 if (II == RewriteMap.end())
15780 continue;
15781
15782 // Don't print things that are not interesting.
15783 if (II->second.second == Expr)
15784 continue;
15785
15786 OS.indent(Depth) << "[PSE]" << I << ":\n";
15787 OS.indent(Depth + 2) << *Expr << "\n";
15788 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15789 }
15790}
15791
// Guards are collected from the edge entering the loop header from its
// predecessor.
15794 BasicBlock *Header = L->getHeader();
15795 BasicBlock *Pred = L->getLoopPredecessor();
15796 LoopGuards Guards(SE);
// Without a loop predecessor (null), the empty guard set is returned.
15797 if (!Pred)
15798 return Guards;
15800 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15801 return Guards;
15802}
15803
// Derive a guard for Phi from its incoming edges. Suppose that, for every
// incoming value, the guards collected for its incoming block rewrite it to a
// min/max whose first operand is a constant, and all use the same min/max
// kind. Then record Phi -> minmax(C, Phi), where C is the weakest of those
// constants.
15804void ScalarEvolution::LoopGuards::collectFromPHI(
15808 unsigned Depth) {
15809 if (!SE.isSCEVable(Phi.getType()))
15810 return;
15811
// (constant, min/max kind); {nullptr, scCouldNotCompute} means "no pattern".
15812 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15813 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15814 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
// Each incoming block is analyzed at most once per VisitedBlocks set.
15815 if (!VisitedBlocks.insert(InBlock).second)
15816 return {nullptr, scCouldNotCompute};
15817
15818 // Avoid analyzing unreachable blocks so that we don't get trapped
15819 // traversing cycles with ill-formed dominance or infinite cycles
15820 if (!SE.DT.isReachableFromEntry(InBlock))
15821 return {nullptr, scCouldNotCompute};
15822
// Guards for an incoming block are collected once and memoized in
// IncomingGuards.
15823 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15824 if (Inserted)
15825 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15826 Depth + 1);
15827 auto &RewriteMap = G->second.RewriteMap;
15828 if (RewriteMap.empty())
15829 return {nullptr, scCouldNotCompute};
15830 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15831 if (S == RewriteMap.end())
15832 return {nullptr, scCouldNotCompute};
15833 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15834 if (!SM)
15835 return {nullptr, scCouldNotCompute};
15836 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15837 return {C0, SM->getSCEVType()};
15838 return {nullptr, scCouldNotCompute};
15839 };
// Combine two patterns of the same kind by keeping the constant that holds on
// both edges: the smaller lower bound for max, the larger upper bound for min.
15840 auto MergeMinMaxConst = [](MinMaxPattern P1,
15841 MinMaxPattern P2) -> MinMaxPattern {
15842 auto [C1, T1] = P1;
15843 auto [C2, T2] = P2;
15844 if (!C1 || !C2 || T1 != T2)
15845 return {nullptr, scCouldNotCompute};
15846 switch (T1) {
15847 case scUMaxExpr:
15848 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15849 case scSMaxExpr:
15850 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15851 case scUMinExpr:
15852 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15853 case scSMinExpr:
15854 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15855 default:
15856 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15857 }
15858 };
// Once no common pattern is possible, stop: the remaining incoming blocks are
// then not visited (nor inserted into VisitedBlocks).
15859 auto P = GetMinMaxConst(0);
15860 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15861 if (!P.first)
15862 break;
15863 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15864 }
15865 if (P.first) {
15866 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15867 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15868 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15869 Guards.RewriteMap.insert({LHS, RHS});
15870 }
15871}
15872
15873// Return a new SCEV that modifies \p Expr to the closest number divides by
15874// \p Divisor and less or equal than Expr. For now, only handle constant
15875// Expr.
15877 const APInt &DivisorVal,
15878 ScalarEvolution &SE) {
15879 const APInt *ExprVal;
15880 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15881 DivisorVal.isNonPositive())
15882 return Expr;
15883 APInt Rem = ExprVal->urem(DivisorVal);
15884 // return the SCEV: Expr - Expr % Divisor
15885 return SE.getConstant(*ExprVal - Rem);
15886}
15887
15888// Return a new SCEV that modifies \p Expr to the closest number divides by
15889// \p Divisor and greater or equal than Expr. For now, only handle constant
15890// Expr.
15891static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15892 const APInt &DivisorVal,
15893 ScalarEvolution &SE) {
15894 const APInt *ExprVal;
15895 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15896 DivisorVal.isNonPositive())
15897 return Expr;
15898 APInt Rem = ExprVal->urem(DivisorVal);
15899 if (Rem.isZero())
15900 return Expr;
15901 // return the SCEV: Expr + Divisor - Expr % Divisor
15902 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15903}
15904
15906 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15909 // If we have LHS == 0, check if LHS is computing a property of some unknown
15910 // SCEV %v which we can rewrite %v to express explicitly.
15912 return false;
15913 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15914 // explicitly express that.
15915 const SCEVUnknown *URemLHS = nullptr;
15916 const SCEV *URemRHS = nullptr;
15917 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15918 return false;
15919
15920 const SCEV *Multiple =
15921 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15922 DivInfo[URemLHS] = Multiple;
15923 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15924 Multiples[URemLHS] = C->getAPInt();
15925 return true;
15926}
15927
15928// Check if the condition is a divisibility guard (A % B == 0).
15929static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15930 ScalarEvolution &SE) {
15931 const SCEV *X, *Y;
15932 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15933}
15934
15935// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15936// recursively. This is done by aligning up/down the constant value to the
15937// Divisor.
15938static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15939 APInt Divisor,
15940 ScalarEvolution &SE) {
15941 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15942 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15943 // the non-constant operand and in \p LHS the constant operand.
15944 auto IsMinMaxSCEVWithNonNegativeConstant =
15945 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15946 const SCEV *&RHS) {
15947 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15948 if (MinMax->getNumOperands() != 2)
15949 return false;
15950 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15951 if (C->getAPInt().isNegative())
15952 return false;
15953 SCTy = MinMax->getSCEVType();
15954 LHS = MinMax->getOperand(0);
15955 RHS = MinMax->getOperand(1);
15956 return true;
15957 }
15958 }
15959 return false;
15960 };
15961
15962 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15963 SCEVTypes SCTy;
15964 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15965 MinMaxRHS))
15966 return MinMaxExpr;
15967 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15968 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15969 auto *DivisibleExpr =
15970 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15971 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15973 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15974 return SE.getMinMaxExpr(SCTy, Ops);
15975}
15976
15977void ScalarEvolution::LoopGuards::collectFromBlock(
15978 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15979 const BasicBlock *Block, const BasicBlock *Pred,
15980 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15981
15983
15984 SmallVector<SCEVUse> ExprsToRewrite;
15985 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15986 const SCEV *RHS,
15987 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15988 const LoopGuards &DivGuards) {
15989 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15990 // replacement SCEV which isn't directly implied by the structure of that
15991 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15992 // legal. See the scoping rules for flags in the header to understand why.
15993
15994 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15995 // create this form when combining two checks of the form (X u< C2 + C1) and
15996 // (X >=u C1).
15997 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15998 &ExprsToRewrite]() {
15999 const SCEVConstant *C1;
16000 const SCEVUnknown *LHSUnknown;
16001 auto *C2 = dyn_cast<SCEVConstant>(RHS);
16002 if (!match(LHS,
16003 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
16004 !C2)
16005 return false;
16006
16007 auto ExactRegion =
16008 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
16009 .sub(C1->getAPInt());
16010
16011 // Bail out, unless we have a non-wrapping, monotonic range.
16012 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
16013 return false;
16014 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
16015 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
16016 I->second = SE.getUMaxExpr(
16017 SE.getConstant(ExactRegion.getUnsignedMin()),
16018 SE.getUMinExpr(RewrittenLHS,
16019 SE.getConstant(ExactRegion.getUnsignedMax())));
16020 ExprsToRewrite.push_back(LHSUnknown);
16021 return true;
16022 };
16023 if (MatchRangeCheckIdiom())
16024 return;
16025
16026 // Do not apply information for constants or if RHS contains an AddRec.
16028 return;
16029
16030 // If RHS is SCEVUnknown, make sure the information is applied to it.
16032 std::swap(LHS, RHS);
16034 }
16035
16036 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
16037 // and \p FromRewritten are the same (i.e. there has been no rewrite
16038 // registered for \p From), then puts this value in the list of rewritten
16039 // expressions.
16040 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
16041 const SCEV *To) {
16042 if (From == FromRewritten)
16043 ExprsToRewrite.push_back(From);
16044 RewriteMap[From] = To;
16045 };
16046
16047 // Checks whether \p S has already been rewritten. In that case returns the
16048 // existing rewrite because we want to chain further rewrites onto the
16049 // already rewritten value. Otherwise returns \p S.
16050 auto GetMaybeRewritten = [&](const SCEV *S) {
16051 return RewriteMap.lookup_or(S, S);
16052 };
16053
16054 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
16055 // Apply divisibility information when computing the constant multiple.
16056 const APInt &DividesBy =
16057 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
16058
16059 // Collect rewrites for LHS and its transitive operands based on the
16060 // condition.
16061 // For min/max expressions, also apply the guard to its operands:
16062 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
16063 // 'min(a, b) > c' -> '(a > c) and (b > c)',
16064 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
16065 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
16066
16067 // We cannot express strict predicates in SCEV, so instead we replace them
16068 // with non-strict ones against plus or minus one of RHS depending on the
16069 // predicate.
16070 const SCEV *One = SE.getOne(RHS->getType());
16071 switch (Predicate) {
16072 case CmpInst::ICMP_ULT:
16073 if (RHS->getType()->isPointerTy())
16074 return;
16075 RHS = SE.getUMaxExpr(RHS, One);
16076 [[fallthrough]];
16077 case CmpInst::ICMP_SLT: {
16078 RHS = SE.getMinusSCEV(RHS, One);
16079 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16080 break;
16081 }
16082 case CmpInst::ICMP_UGT:
16083 case CmpInst::ICMP_SGT:
16084 RHS = SE.getAddExpr(RHS, One);
16085 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16086 break;
16087 case CmpInst::ICMP_ULE:
16088 case CmpInst::ICMP_SLE:
16089 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16090 break;
16091 case CmpInst::ICMP_UGE:
16092 case CmpInst::ICMP_SGE:
16093 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16094 break;
16095 default:
16096 break;
16097 }
16098
16099 SmallVector<SCEVUse, 16> Worklist(1, LHS);
16100 SmallPtrSet<const SCEV *, 16> Visited;
16101
16102 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
16103 append_range(Worklist, S->operands());
16104 };
16105
16106 while (!Worklist.empty()) {
16107 const SCEV *From = Worklist.pop_back_val();
16108 if (isa<SCEVConstant>(From))
16109 continue;
16110 if (!Visited.insert(From).second)
16111 continue;
16112 const SCEV *FromRewritten = GetMaybeRewritten(From);
16113 const SCEV *To = nullptr;
16114
16115 switch (Predicate) {
16116 case CmpInst::ICMP_ULT:
16117 case CmpInst::ICMP_ULE:
16118 To = SE.getUMinExpr(FromRewritten, RHS);
16119 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
16120 EnqueueOperands(UMax);
16121 break;
16122 case CmpInst::ICMP_SLT:
16123 case CmpInst::ICMP_SLE:
16124 To = SE.getSMinExpr(FromRewritten, RHS);
16125 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
16126 EnqueueOperands(SMax);
16127 break;
16128 case CmpInst::ICMP_UGT:
16129 case CmpInst::ICMP_UGE:
16130 To = SE.getUMaxExpr(FromRewritten, RHS);
16131 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
16132 EnqueueOperands(UMin);
16133 break;
16134 case CmpInst::ICMP_SGT:
16135 case CmpInst::ICMP_SGE:
16136 To = SE.getSMaxExpr(FromRewritten, RHS);
16137 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
16138 EnqueueOperands(SMin);
16139 break;
16140 case CmpInst::ICMP_EQ:
16142 To = RHS;
16143 break;
16144 case CmpInst::ICMP_NE:
16145 if (match(RHS, m_scev_Zero())) {
16146 const SCEV *OneAlignedUp =
16147 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16148 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16149 } else {
16150 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16151 // but creating the subtraction eagerly is expensive. Track the
16152 // inequalities in a separate map, and materialize the rewrite lazily
16153 // when encountering a suitable subtraction while re-writing.
16154 if (LHS->getType()->isPointerTy()) {
16158 break;
16159 }
16160 const SCEVConstant *C;
16161 const SCEV *A, *B;
16164 RHS = A;
16165 LHS = B;
16166 }
16167 if (LHS > RHS)
16168 std::swap(LHS, RHS);
16169 Guards.NotEqual.insert({LHS, RHS});
16170 continue;
16171 }
16172 break;
16173 default:
16174 break;
16175 }
16176
16177 if (To)
16178 AddRewrite(From, FromRewritten, To);
16179 }
16180 };
16181
16183 // First, collect information from assumptions dominating the loop.
16184 for (auto &AssumeVH : SE.AC.assumptions()) {
16185 if (!AssumeVH)
16186 continue;
16187 auto *AssumeI = cast<CallInst>(AssumeVH);
16188 if (!SE.DT.dominates(AssumeI, Block))
16189 continue;
16190 Terms.emplace_back(AssumeI->getOperand(0), true);
16191 }
16192
16193 // Second, collect information from llvm.experimental.guards dominating the loop.
16194 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16195 SE.F.getParent(), Intrinsic::experimental_guard);
16196 if (GuardDecl)
16197 for (const auto *GU : GuardDecl->users())
16198 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16199 if (Guard->getFunction() == Block->getParent() &&
16200 SE.DT.dominates(Guard, Block))
16201 Terms.emplace_back(Guard->getArgOperand(0), true);
16202
16203 // Third, collect conditions from dominating branches. Starting at the loop
16204 // predecessor, climb up the predecessor chain, as long as there are
16205 // predecessors that can be found that have unique successors leading to the
16206 // original header.
16207 // TODO: share this logic with isLoopEntryGuardedByCond.
16208 unsigned NumCollectedConditions = 0;
16210 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16211 for (; Pair.first;
16212 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16213 VisitedBlocks.insert(Pair.second);
16214 const CondBrInst *LoopEntryPredicate =
16215 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16216 if (!LoopEntryPredicate)
16217 continue;
16218
16219 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16220 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16221 NumCollectedConditions++;
16222
16223 // If we are recursively collecting guards stop after 2
16224 // conditions to limit compile-time impact for now.
16225 if (Depth > 0 && NumCollectedConditions == 2)
16226 break;
16227 }
16228 // Finally, if we stopped climbing the predecessor chain because
16229 // there wasn't a unique one to continue, try to collect conditions
16230 // for PHINodes by recursively following all of their incoming
16231 // blocks and try to merge the found conditions to build a new one
16232 // for the Phi.
16233 if (Pair.second->hasNPredecessorsOrMore(2) &&
16235 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16236 for (auto &Phi : Pair.second->phis())
16237 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16238 }
16239
16240 // Now apply the information from the collected conditions to
16241 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16242 // earliest conditions is processed first, except guards with divisibility
16243 // information, which are moved to the back. This ensures the SCEVs with the
16244 // shortest dependency chains are constructed first.
16246 GuardsToProcess;
16247 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16248 SmallVector<Value *, 8> Worklist;
16249 SmallPtrSet<Value *, 8> Visited;
16250 Worklist.push_back(Term);
16251 while (!Worklist.empty()) {
16252 Value *Cond = Worklist.pop_back_val();
16253 if (!Visited.insert(Cond).second)
16254 continue;
16255
16256 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16257 auto Predicate =
16258 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16259 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16260 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16261 // If LHS is a constant, apply information to the other expression.
16262 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16263 // can improve results.
16264 if (isa<SCEVConstant>(LHS)) {
16265 std::swap(LHS, RHS);
16267 }
16268 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16269 continue;
16270 }
16271
16272 Value *L, *R;
16273 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16274 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16275 Worklist.push_back(L);
16276 Worklist.push_back(R);
16277 }
16278 }
16279 }
16280
16281 // Process divisibility guards in reverse order to populate DivGuards early.
16282 DenseMap<const SCEV *, APInt> Multiples;
16283 LoopGuards DivGuards(SE);
16284 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16285 if (!isDivisibilityGuard(LHS, RHS, SE))
16286 continue;
16287 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16288 Multiples, SE);
16289 }
16290
16291 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16292 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16293
16294 // Apply divisibility information last. This ensures it is applied to the
16295 // outermost expression after other rewrites for the given value.
16296 for (const auto &[K, Divisor] : Multiples) {
16297 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16298 Guards.RewriteMap[K] =
16300 Guards.rewrite(K), Divisor, SE),
16301 DivisorSCEV),
16302 DivisorSCEV);
16303 ExprsToRewrite.push_back(K);
16304 }
16305
16306 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16307 // the replacement expressions are contained in the ranges of the replaced
16308 // expressions.
16309 Guards.PreserveNUW = true;
16310 Guards.PreserveNSW = true;
16311 for (const SCEV *Expr : ExprsToRewrite) {
16312 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16313 Guards.PreserveNUW &=
16314 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16315 Guards.PreserveNSW &=
16316 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16317 }
16318
16319 // Now that all rewrite information is collect, rewrite the collected
16320 // expressions with the information in the map. This applies information to
16321 // sub-expressions.
16322 if (ExprsToRewrite.size() > 1) {
16323 for (const SCEV *Expr : ExprsToRewrite) {
16324 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16325 Guards.RewriteMap.erase(Expr);
16326 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16327 }
16328 }
16329}
16330
16332 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16333 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16334 /// replacement is loop invariant in the loop of the AddRec.
16335 class SCEVLoopGuardRewriter
16336 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16339
16341
16342 public:
16343 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16344 const ScalarEvolution::LoopGuards &Guards)
16345 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16346 NotEqual(Guards.NotEqual) {
16347 if (Guards.PreserveNUW)
16348 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16349 if (Guards.PreserveNSW)
16350 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16351 }
16352
16353 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16354
16355 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16356 return Map.lookup_or(Expr, Expr);
16357 }
16358
16359 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16360 if (const SCEV *S = Map.lookup(Expr))
16361 return S;
16362
16363 // If we didn't find the extact ZExt expr in the map, check if there's
16364 // an entry for a smaller ZExt we can use instead.
16365 Type *Ty = Expr->getType();
16366 const SCEV *Op = Expr->getOperand(0);
16367 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16368 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16369 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16370 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16371 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16372 if (const SCEV *S = Map.lookup(NarrowExt))
16373 return SE.getZeroExtendExpr(S, Ty);
16374 Bitwidth = Bitwidth / 2;
16375 }
16376
16378 Expr);
16379 }
16380
16381 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16382 if (const SCEV *S = Map.lookup(Expr))
16383 return S;
16385 Expr);
16386 }
16387
16388 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16389 if (const SCEV *S = Map.lookup(Expr))
16390 return S;
16392 }
16393
16394 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16395 if (const SCEV *S = Map.lookup(Expr))
16396 return S;
16398 }
16399
16400 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16401 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16402 // return UMax(S, 1).
16403 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16404 SCEVUse LHS, RHS;
16405 if (MatchBinarySub(S, LHS, RHS)) {
16406 if (LHS > RHS)
16407 std::swap(LHS, RHS);
16408 if (NotEqual.contains({LHS, RHS})) {
16409 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16410 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16411 return SE.getUMaxExpr(OneAlignedUp, S);
16412 }
16413 }
16414 return nullptr;
16415 };
16416
16417 // Check if Expr itself is a subtraction pattern with guard info.
16418 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16419 return Rewritten;
16420
16421 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16422 // (Const + A + B). There may be guard info for A + B, and if so, apply
16423 // it.
16424 // TODO: Could more generally apply guards to Add sub-expressions.
16425 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16426 Expr->getNumOperands() == 3) {
16427 const SCEV *Add =
16428 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16429 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16430 return SE.getAddExpr(
16431 Expr->getOperand(0), Rewritten,
16432 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16433 if (const SCEV *S = Map.lookup(Add))
16434 return SE.getAddExpr(Expr->getOperand(0), S);
16435 }
16436 SmallVector<SCEVUse, 2> Operands;
16437 bool Changed = false;
16438 for (SCEVUse Op : Expr->operands()) {
16439 Operands.push_back(
16441 Changed |= Op != Operands.back();
16442 }
16443 // We are only replacing operands with equivalent values, so transfer the
16444 // flags from the original expression.
16445 return !Changed ? Expr
16446 : SE.getAddExpr(Operands,
16448 Expr->getNoWrapFlags(), FlagMask));
16449 }
16450
16451 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16452 SmallVector<SCEVUse, 2> Operands;
16453 bool Changed = false;
16454 for (SCEVUse Op : Expr->operands()) {
16455 Operands.push_back(
16457 Changed |= Op != Operands.back();
16458 }
16459 // We are only replacing operands with equivalent values, so transfer the
16460 // flags from the original expression.
16461 return !Changed ? Expr
16462 : SE.getMulExpr(Operands,
16464 Expr->getNoWrapFlags(), FlagMask));
16465 }
16466 };
16467
16468 if (RewriteMap.empty() && NotEqual.empty())
16469 return Expr;
16470
16471 SCEVLoopGuardRewriter Rewriter(SE, *this);
16472 return Rewriter.visit(Expr);
16473}
16474
16475const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16476 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16477}
16478
16480 const LoopGuards &Guards) {
16481 return Guards.rewrite(Expr);
16482}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:856
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:663
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
static constexpr Value * getValue(Ty &ValueOrUse)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
MachineInstr unsigned OpIdx
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static std::pair< ConstantRange, bool > getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2006
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:640
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1535
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:968
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1818
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1692
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
LLVM_ABI APInt uadd_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1970
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1670
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1784
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
Definition APInt.h:476
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1300
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1028
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition APInt.h:1244
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:130
size_t size() const
Get the array size.
Definition ArrayRef.h:141
iterator begin() const
Definition ArrayRef.h:129
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:409
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
@ ICMP_SLT
signed less than
Definition InstrTypes.h:769
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:770
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:764
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:763
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:767
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:765
@ ICMP_NE
not equal
Definition InstrTypes.h:762
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:768
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:766
bool isSigned() const
Definition InstrTypes.h:993
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:890
bool isTrueWhenEqual() const
This is just a convenience.
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:852
bool isUnsigned() const
Definition InstrTypes.h:999
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:989
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1497
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:791
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:252
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:301
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:135
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:238
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:221
iterator end()
Definition DenseMap.h:143
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:216
void swap(DerivedT &RHS)
Definition DenseMap.h:439
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:286
Analysis pass which computes a DominatorTree.
Definition Dominators.h:270
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:306
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:151
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This class describes a reference to an interned FoldingSetNodeID, which can be a useful to store node...
Definition FoldingSet.h:171
This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:208
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
bool isEquality() const
Return true if this predicate is either EQ or NE.
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:348
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:587
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:612
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1069
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:113
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:107
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2, ArrayRef< const SCEVPredicate * > ExtraPreds={}) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds and ExtraPreds.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've statically proved that V doesn't wrap.
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V, SmallVectorImpl< const SCEVPredicate * > *WrapPredsAdded=nullptr)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI void addPredicates(ArrayRef< const SCEVPredicate * > Preds)
Adds all predicates in Preds.
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
SCEVUnionPredicate getUnionWith(const SCEVPredicate *N, ScalarEvolution &SE) const
Returns a new SCEVUnionPredicate that is the union of this predicate and the given predicate N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Returns true if operation BinOp between LHS and RHS provably does not have a signed (if Signed) or unsigned overflow.
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static LLVM_ABI bool isGuaranteedNotToBePoison(const SCEV *Op)
Returns true if Op is guaranteed to not be poison.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which cannot resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI bool isLoopUniform(const SCEV *S, const Loop *L)
Returns true if the given SCEV is loop-uniform with respect to the specified loop L.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
@ LoopUniform
The SCEV is loop-uniform.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:301
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:743
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:774
TypeSize getSizeInBits() const
Definition DataLayout.h:754
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:309
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:282
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:306
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:270
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:313
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:319
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2287
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2847
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2292
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:830
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
match_combine_or< Ty... > m_CombineOr(const Ty &...Ps)
Combine pattern matchers matching any of Ps patterns.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BasicBlock()
Match an arbitrary basic block value and ignore it.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
auto m_Value()
Match an arbitrary value and ignore it.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
match_bind< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
brc_match< Cond_t, match_bind< BasicBlock >, match_bind< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
match_bind< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVAffineAddRec_match< Op0_t, Op1_t, match_isa< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:573
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2116
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
RelativeUniformCounterPtr Values
Definition InstrProf.h:91
@ Dead
Unused definition.
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if its even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2208
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2111
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
RelativeUniformCounterPtr ValuesPtrExpr VTableAddr Value
Definition InstrProf.h:143
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:204
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2200
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:377
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
constexpr T divideCeil(U Numerator, V Denominator)
Returns the integer ceil(Numerator / Denominator).
Definition MathExtras.h:394
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
RelativeUniformCounterPtr ValuesPtrExpr VTableAddr Count
Definition InstrProf.h:145
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
IntPtrTy
Definition InstrProf.h:82
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2088
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:860
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:106
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:200
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:130
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:103
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.