31#define DEBUG_TYPE "riscv-gather-scatter-lowering"
62 return "RISC-V gather/scatter lowering";
68 std::pair<Value *, Value *> determineBaseAndStride(
Instruction *Ptr,
78char RISCVGatherScatterLowering::ID = 0;
81 "RISC-V gather/scatter lowering pass",
false,
false)
84 return new RISCVGatherScatterLowering();
90 return std::make_pair(
nullptr,
nullptr);
98 return std::make_pair(
nullptr,
nullptr);
99 APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
101 for (
unsigned i = 1; i != NumElts; ++i) {
104 return std::make_pair(
nullptr,
nullptr);
108 StrideVal = LocalStride;
109 else if (StrideVal != LocalStride)
110 return std::make_pair(
nullptr,
nullptr);
115 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);
117 return std::make_pair(StartVal, Stride);
129 auto *Ty = Start->getType()->getScalarType();
130 return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
136 if (!BO || (BO->getOpcode() != Instruction::Add &&
137 BO->getOpcode() != Instruction::Or &&
138 BO->getOpcode() != Instruction::Shl &&
139 BO->getOpcode() != Instruction::Mul))
140 return std::make_pair(
nullptr,
nullptr);
142 if (BO->getOpcode() == Instruction::Or &&
144 return std::make_pair(
nullptr,
nullptr);
147 unsigned OtherIndex = 0;
154 return std::make_pair(
nullptr,
nullptr);
160 return std::make_pair(
nullptr,
nullptr);
162 Builder.SetInsertPoint(BO);
163 Builder.SetCurrentDebugLocation(
DebugLoc());
166 switch (BO->getOpcode()) {
169 case Instruction::Or:
170 Start = Builder.CreateOr(Start,
Splat,
"",
true);
172 case Instruction::Add:
173 Start = Builder.CreateAdd(Start,
Splat);
175 case Instruction::Mul:
176 Start = Builder.CreateMul(Start,
Splat);
177 Stride = Builder.CreateMul(Stride,
Splat);
179 case Instruction::Shl:
180 Start = Builder.CreateShl(Start,
Splat);
181 Stride = Builder.CreateShl(Stride,
Splat);
185 return std::make_pair(Start, Stride);
192bool RISCVGatherScatterLowering::matchStridedRecurrence(
Value *Index,
Loop *L,
201 if (
Phi->getParent() !=
L->getHeader())
208 assert(
Phi->getNumIncomingValues() == 2 &&
"Expected 2 operand phi.");
209 unsigned IncrementingBlock =
Phi->getIncomingValue(0) == Inc ? 0 : 1;
210 assert(
Phi->getIncomingValue(IncrementingBlock) == Inc &&
211 "Expected one operand of phi to be Inc");
221 assert(Stride !=
nullptr);
226 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->
getName() +
".scalar",
228 BasePtr->addIncoming(Start,
Phi->getIncomingBlock(1 - IncrementingBlock));
229 BasePtr->addIncoming(Inc,
Phi->getIncomingBlock(IncrementingBlock));
232 MaybeDeadPHIs.push_back(Phi);
241 switch (BO->getOpcode()) {
244 case Instruction::Or:
249 case Instruction::Add:
251 case Instruction::Shl:
253 case Instruction::Mul:
262 OtherOp = BO->getOperand(1);
267 OtherOp = BO->getOperand(0);
273 if (!
L->isLoopInvariant(OtherOp))
282 if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
287 unsigned StartBlock =
BasePtr->getOperand(0) == Inc ? 1 : 0;
293 BasePtr->getIncomingBlock(StartBlock)->getTerminator());
297 switch (BO->getOpcode()) {
300 case Instruction::Add:
301 case Instruction::Or: {
307 case Instruction::Mul: {
309 Stride = Builder.
CreateMul(Stride, SplatOp,
"stride");
312 case Instruction::Shl: {
314 Stride = Builder.
CreateShl(Stride, SplatOp,
"stride");
324 switch (BO->getOpcode()) {
327 case Instruction::Mul:
328 Step = Builder.
CreateMul(Step, SplatOp,
"step");
330 case Instruction::Shl:
331 Step = Builder.
CreateShl(Step, SplatOp,
"step");
336 BasePtr->setIncomingValue(StartBlock, Start);
340std::pair<Value *, Value *>
341RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
342 IRBuilderBase &Builder) {
347 return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
352 return std::make_pair(
nullptr,
nullptr);
354 auto I = StridedAddrs.find(
GEP);
355 if (
I != StridedAddrs.end())
358 SmallVector<Value *, 2>
Ops(
GEP->operands());
363 BaseInst && BaseInst->getType()->isVectorTy()) {
365 auto IsScalar = [](
Value *Idx) {
return !Idx->getType()->isVectorTy(); };
367 auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
372 Builder.
CreateGEP(
GEP->getSourceElementType(), BaseBase, Indices,
373 GEP->getName() +
"offset",
GEP->isInBounds());
374 return {OffsetBase, Stride};
384 return std::make_pair(
nullptr,
nullptr);
387 std::optional<unsigned> VecOperand;
388 unsigned TypeScale = 0;
392 for (
unsigned i = 1, e =
GEP->getNumOperands(); i != e; ++i, ++GTI) {
397 return std::make_pair(
nullptr,
nullptr);
403 return std::make_pair(
nullptr,
nullptr);
410 return std::make_pair(
nullptr,
nullptr);
418 Type *VecIntPtrTy =
DL->getIntPtrType(
GEP->getType());
419 if (VecIndex->
getType() != VecIntPtrTy) {
422 return std::make_pair(
nullptr,
nullptr);
438 Type *SourceTy =
GEP->getSourceElementType();
448 Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
450 auto P = std::make_pair(BasePtr, Stride);
451 StridedAddrs[
GEP] =
P;
457 if (!L || !
L->getLoopPreheader() || !
L->getLoopLatch())
458 return std::make_pair(
nullptr,
nullptr);
462 if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
463 return std::make_pair(
nullptr,
nullptr);
466 unsigned IncrementingBlock = BasePhi->
getOperand(0) == Inc ? 0 : 1;
468 "Expected one operand of phi to be Inc");
473 Ops[*VecOperand] = BasePhi;
474 Type *SourceTy =
GEP->getSourceElementType();
488 Stride = Builder.
CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
490 auto P = std::make_pair(BasePtr, Stride);
491 StridedAddrs[
GEP] =
P;
495bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *
II) {
497 Value *StoreVal =
nullptr, *Ptr, *
Mask, *EVL =
nullptr;
499 switch (
II->getIntrinsicID()) {
500 case Intrinsic::masked_gather:
502 Ptr =
II->getArgOperand(0);
503 Alignment =
II->getParamAlign(0).valueOrOne();
504 Mask =
II->getArgOperand(1);
506 case Intrinsic::vp_gather:
508 Ptr =
II->getArgOperand(0);
510 Alignment =
II->getParamAlign(0).value_or(
511 DL->getABITypeAlign(DataType->getElementType()));
512 Mask =
II->getArgOperand(1);
513 EVL =
II->getArgOperand(2);
515 case Intrinsic::masked_scatter:
517 StoreVal =
II->getArgOperand(0);
518 Ptr =
II->getArgOperand(1);
519 Alignment =
II->getParamAlign(1).valueOrOne();
520 Mask =
II->getArgOperand(2);
522 case Intrinsic::vp_scatter:
524 StoreVal =
II->getArgOperand(0);
525 Ptr =
II->getArgOperand(1);
527 Alignment =
II->getParamAlign(1).value_or(
528 DL->getABITypeAlign(DataType->getElementType()));
529 Mask =
II->getArgOperand(2);
530 EVL =
II->getArgOperand(3);
550 LLVMContext &Ctx = PtrI->getContext();
555 std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
558 assert(Stride !=
nullptr);
570 Intrinsic::experimental_vp_strided_load,
575 if (
II->getIntrinsicID() == Intrinsic::masked_gather)
579 Intrinsic::experimental_vp_strided_store,
584 II->replaceAllUsesWith(
Call);
585 II->eraseFromParent();
587 if (PtrI->use_empty())
593bool RISCVGatherScatterLowering::runOnFunction(Function &
F) {
597 auto &TPC = getAnalysis<TargetPassConfig>();
598 auto &TM = TPC.getTM<RISCVTargetMachine>();
599 ST = &TM.getSubtarget<RISCVSubtarget>(
F);
600 if (!
ST->hasVInstructions() || !
ST->useRVVForFixedLengthVectors())
603 TLI =
ST->getTargetLowering();
604 DL = &
F.getDataLayout();
605 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
607 StridedAddrs.clear();
613 for (BasicBlock &BB :
F) {
614 for (Instruction &
I : BB) {
618 switch (
II->getIntrinsicID()) {
619 case Intrinsic::masked_gather:
620 case Intrinsic::masked_scatter:
621 case Intrinsic::vp_gather:
622 case Intrinsic::vp_scatter:
632 for (
auto *
II : Worklist)
633 Changed |= tryCreateStridedLoadStore(
II);
636 while (!MaybeDeadPHIs.empty()) {
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static bool runOnFunction(Function &F, bool PostInlining)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
static std::pair< Value *, Value * > matchStridedStart(Value *Start, IRBuilderBase &Builder)
static std::pair< Value *, Value * > matchStridedConstant(Constant *StartC)
static SymbolRef::Type getType(const Symbol *Sym)
Target-Independent Code Generator Pass Configuration Options pass.
Class for arbitrary precision integers.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
LLVM_ABI void setPreservesCFG()
This function should be called by the pass, iff they do not:
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
BinaryOps getOpcode() const
This is the shared class of boolean and integer constants.
const APInt & getValue() const
Return the constant as an APInt value reference.
This is an important base class in LLVM.
LLVM_ABI Constant * getAggregateElement(unsigned Elt) const
For aggregates (struct/array/vector) return the constant that corresponds to the specified element if...
A parsed version of the target data layout string in and methods for querying it.
FunctionPass class - This class is used to implement most global optimizations.
Common base class shared among various IRBuilders.
LLVM_ABI Value * CreateSelect(Value *C, Value *True, Value *False, const Twine &Name="", Instruction *MDFrom=nullptr)
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
void SetCurrentDebugLocation(DebugLoc L)
Set location information used by debugging information.
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
LLVM_ABI Value * CreateElementCount(Type *Ty, ElementCount EC)
Create an expression which evaluates to the number of elements in EC at runtime.
LLVM_ABI bool isCommutative() const LLVM_READONLY
Return true if the instruction is commutative:
A wrapper class for inspecting calls to intrinsic functions.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Represents a single loop in the control flow graph.
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
bool isLegalStridedLoadStore(EVT DataType, Align Alignment) const
Return true if a stride load store of the given result type and alignment is legal.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
EVT getValueType(const DataLayout &DL, Type *Ty, bool AllowUnknown=false) const
Return the EVT corresponding to this LLVM type.
bool isTypeLegal(EVT VT) const
Return true if the target has native support for the specified value type.
Target-Independent Code Generator Pass Configuration Options.
bool isVectorTy() const
True if this is an instance of VectorType.
LLVM_ABI unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
void setOperand(unsigned i, Value *Val)
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
LLVM_ABI void takeName(Value *V)
Transfer the name from V to this value.
constexpr ScalarTy getFixedValue() const
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
TypeSize getSequentialElementStride(const DataLayout &DL) const
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
LLVM_ABI Value * getSplatValue(const Value *V)
Get splat value if the input is a splat vector or return nullptr.
FunctionPass * createRISCVGatherScatterLoweringPass()
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
generic_gep_type_iterator<> gep_type_iterator
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
gep_type_iterator gep_type_begin(const User *GEP)
LLVM_ABI bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
LLVM_ABI Constant * ConstantFoldCastInstruction(unsigned opcode, Constant *V, Type *DestTy)