33#include "llvm/IR/IntrinsicsX86.h"
46#define DEBUG_TYPE "x86-lower-amx-intrinsics"
51 return FVT->getNumElements() == 256 &&
52 FVT->getElementType()->isIntegerTy(32);
59 cl::desc(
"X86: enable AMX scalarizition."));
62class X86LowerAMXIntrinsics {
67 : Func(
F), DTU(DomTU), LI(LoopI) {}
73 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit,
Value *Bound,
74 Value *Step, StringRef Name, IRBuilderBase &
B,
76 template <
bool IsTileLoad>
77 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
80 template <Intrinsic::ID IntrID>
81 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
82 IntrID == Intrinsic::x86_tdpbsud_internal ||
83 IntrID == Intrinsic::x86_tdpbusd_internal ||
84 IntrID == Intrinsic::x86_tdpbuud_internal ||
85 IntrID == Intrinsic::x86_tdpbf16ps_internal,
87 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &
B,
90 template <
bool IsTileLoad>
91 bool lowerTileLoadStore(Instruction *TileLoadStore);
92 template <Intrinsic::ID IntrID>
93 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
94 IntrID == Intrinsic::x86_tdpbsud_internal ||
95 IntrID == Intrinsic::x86_tdpbusd_internal ||
96 IntrID == Intrinsic::x86_tdpbuud_internal ||
97 IntrID == Intrinsic::x86_tdpbf16ps_internal,
99 lowerTileDP(Instruction *TileDP);
100 bool lowerTileZero(Instruction *TileZero);
116 Type *I16Ty = Type::getInt16Ty(Ctx);
120 PHINode::Create(I16Ty, 2, Name +
".iv", Header->getTerminator()->getIterator());
121 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
123 B.SetInsertPoint(Latch);
124 Value *Inc =
B.CreateAdd(
IV, Step, Name +
".step");
125 Value *
Cond =
B.CreateICmpNE(Inc, Bound, Name +
".cond");
127 IV->addIncoming(Inc, Latch);
133 {DominatorTree::Delete, Preheader, Tmp},
134 {DominatorTree::Insert, Header, Body},
135 {DominatorTree::Insert, Body, Latch},
136 {DominatorTree::Insert, Latch, Header},
137 {DominatorTree::Insert, Latch,
Exit},
138 {DominatorTree::Insert, Preheader, Header},
141 L->addBasicBlockToLoop(Header, *LI);
142 L->addBasicBlockToLoop(Body, *LI);
143 L->addBasicBlockToLoop(Latch, *LI);
148template <
bool IsTileLoad>
149Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
150 BasicBlock *Start, BasicBlock *End, IRBuilderBase &
B,
Value *Row,
152 std::string IntrinName = IsTileLoad ?
"tileload" :
"tilestore";
153 Loop *RowLoop =
nullptr;
154 Loop *ColLoop =
nullptr;
160 ParentL->addChildLoop(RowLoop);
165 BasicBlock *RowBody = createLoop(Start, End, Row,
B.getInt16(1),
166 IntrinName +
".scalarize.rows",
B, RowLoop);
169 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col,
B.getInt16(1),
170 IntrinName +
".scalarize.cols",
B, ColLoop);
177 Type *EltTy =
B.getInt32Ty();
184 Value *CurrentRowZExt =
B.CreateZExt(CurrentRow, Stride->
getType());
185 Value *CurrentColZExt =
B.CreateZExt(CurrentCol, Stride->
getType());
187 B.CreateAdd(
B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
189 Value *Idx =
B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentCol);
196 PHINode *VecCPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.phi.row");
203 PHINode *VecPhi =
B.CreatePHI(V256I32Ty, 2,
"vec.phi");
212 Value *Elt =
B.CreateLoad(EltTy, EltPtr);
213 Value *ResVec =
B.CreateInsertElement(VecPhi, Elt, Idx);
220 Value *Vec = BitCast->getOperand(0);
228 Value *Elt =
B.CreateExtractElement(Vec, Idx);
230 B.CreateStore(Elt, EltPtr);
235template <Intrinsic::ID IntrID>
236std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
237 IntrID == Intrinsic::x86_tdpbsud_internal ||
238 IntrID == Intrinsic::x86_tdpbusd_internal ||
239 IntrID == Intrinsic::x86_tdpbuud_internal ||
240 IntrID == Intrinsic::x86_tdpbf16ps_internal,
242X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
243 IRBuilderBase &
B,
Value *Row,
246 std::string IntrinName;
248 case Intrinsic::x86_tdpbssd_internal:
249 IntrinName =
"tiledpbssd";
251 case Intrinsic::x86_tdpbsud_internal:
252 IntrinName =
"tiledpbsud";
254 case Intrinsic::x86_tdpbusd_internal:
255 IntrinName =
"tiledpbusd";
257 case Intrinsic::x86_tdpbuud_internal:
258 IntrinName =
"tiledpbuud";
260 case Intrinsic::x86_tdpbf16ps_internal:
261 IntrinName =
"tiledpbf16ps";
264 Loop *RowLoop =
nullptr;
265 Loop *ColLoop =
nullptr;
266 Loop *InnerLoop =
nullptr;
274 ParentL->addChildLoop(RowLoop);
279 BasicBlock *RowBody = createLoop(Start, End, Row,
B.getInt16(1),
280 IntrinName +
".scalarize.rows",
B, RowLoop);
283 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col,
B.getInt16(1),
284 IntrinName +
".scalarize.cols",
B, ColLoop);
290 createLoop(ColBody, ColLoopLatch, K,
B.getInt16(1),
291 IntrinName +
".scalarize.inner",
B, InnerLoop);
299 Value *CurrentInner = &*InnerLoopHeader->
begin();
303 Value *VecC = BitCastAcc->getOperand(0);
309 Value *VecA = BitCastLHS->getOperand(0);
312 Value *VecB = BitCastRHS->getOperand(0);
322 PHINode *VecCPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.c.phi.row");
325 PHINode *VecDPhiRowLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.d.phi.row");
339 PHINode *VecCPhiColLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.c.phi.col");
340 VecCPhiColLoop->
addIncoming(VecCPhiRowLoop, RowBody);
341 PHINode *VecDPhiColLoop =
B.CreatePHI(V256I32Ty, 2,
"vec.d.phi.col");
342 VecDPhiColLoop->
addIncoming(VecDPhiRowLoop, RowBody);
344 B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentCol);
352 PHINode *VecCPhi =
B.CreatePHI(V256I32Ty, 2,
"vec.c.inner.phi");
357 B.CreateAdd(
B.CreateMul(CurrentRow,
B.getInt16(16)), CurrentInner);
359 B.CreateAdd(
B.CreateMul(CurrentInner,
B.getInt16(16)), CurrentCol);
360 Value *NewVecC =
nullptr;
362 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
379 Value *EltC =
B.CreateExtractElement(VecCPhi, IdxC);
380 Value *EltA =
B.CreateExtractElement(VecA, IdxA);
381 Value *SubVecA =
B.CreateBitCast(EltA, V4I8Ty);
382 Value *EltB =
B.CreateExtractElement(VecB, IdxB);
383 Value *SubVecB =
B.CreateBitCast(EltB, V4I8Ty);
384 Value *SEXTSubVecB =
nullptr;
385 Value *SEXTSubVecA =
nullptr;
387 case Intrinsic::x86_tdpbssd_internal:
388 SEXTSubVecB =
B.CreateSExt(SubVecB, V4I32Ty);
389 SEXTSubVecA =
B.CreateSExt(SubVecA, V4I32Ty);
391 case Intrinsic::x86_tdpbsud_internal:
392 SEXTSubVecB =
B.CreateZExt(SubVecB, V4I32Ty);
393 SEXTSubVecA =
B.CreateSExt(SubVecA, V4I32Ty);
395 case Intrinsic::x86_tdpbusd_internal:
396 SEXTSubVecB =
B.CreateSExt(SubVecB, V4I32Ty);
397 SEXTSubVecA =
B.CreateZExt(SubVecA, V4I32Ty);
399 case Intrinsic::x86_tdpbuud_internal:
400 SEXTSubVecB =
B.CreateZExt(SubVecB, V4I32Ty);
401 SEXTSubVecA =
B.CreateZExt(SubVecA, V4I32Ty);
406 Value *SubVecR =
B.CreateAddReduce(
B.CreateMul(SEXTSubVecA, SEXTSubVecB));
407 Value *ResElt =
B.CreateAdd(EltC, SubVecR);
408 NewVecC =
B.CreateInsertElement(VecCPhi, ResElt, IdxC);
434 Value *EltC =
B.CreateExtractElement(VecCPhi, IdxC);
435 Value *EltCF32 =
B.CreateBitCast(EltC,
B.getFloatTy());
436 Value *EltA =
B.CreateExtractElement(VecA, IdxA);
437 Value *SubVecA =
B.CreateBitCast(EltA, V2I16Ty);
438 Value *EltB =
B.CreateExtractElement(VecB, IdxB);
439 Value *SubVecB =
B.CreateBitCast(EltB, V2I16Ty);
441 int ShuffleMask[4] = {2, 0, 3, 1};
442 auto ShuffleArray =
ArrayRef(ShuffleMask);
443 Value *AV2F32 =
B.CreateBitCast(
444 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
445 Value *BV2F32 =
B.CreateBitCast(
446 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
447 Value *SubVecR =
B.CreateFAddReduce(EltCF32,
B.CreateFMul(AV2F32, BV2F32));
448 Value *ResElt =
B.CreateBitCast(SubVecR,
B.getInt32Ty());
449 NewVecC =
B.CreateInsertElement(VecCPhi, ResElt, IdxC);
457 Value *NewEltC =
B.CreateExtractElement(NewVecC, IdxC);
458 Value *NewVecD =
B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
462 VecCPhiColLoop->
addIncoming(NewVecC, ColLoopLatch);
464 VecDPhiColLoop->
addIncoming(NewVecD, ColLoopLatch);
469template <Intrinsic::ID IntrID>
470std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
471 IntrID == Intrinsic::x86_tdpbsud_internal ||
472 IntrID == Intrinsic::x86_tdpbusd_internal ||
473 IntrID == Intrinsic::x86_tdpbuud_internal ||
474 IntrID == Intrinsic::x86_tdpbf16ps_internal,
476X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
482 PreBuilder.SetInsertPoint(TileDP);
486 Value *NDWord = PreBuilder.CreateLShr(
N, PreBuilder.getInt16(2));
487 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
492 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
498 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
504 I->replaceAllUsesWith(ResVec);
505 I->eraseFromParent();
513template <
bool IsTileLoad>
514bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
527 PreBuilder.SetInsertPoint(TileLoadStore);
528 Value *NDWord = PreBuilder.CreateLShr(
N, PreBuilder.getInt16(2));
529 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
534 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
535 Start, End, Builder, M, NDWord,
Ptr, StrideDWord,
536 IsTileLoad ?
nullptr : Tile);
542 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
548 I->replaceAllUsesWith(ResVec);
549 I->eraseFromParent();
558bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
566 I->replaceAllUsesWith(VecZero);
567 I->eraseFromParent();
574bool X86LowerAMXIntrinsics::visit() {
580 switch (Inst->getIntrinsicID()) {
581 case Intrinsic::x86_tdpbssd_internal:
582 case Intrinsic::x86_tdpbsud_internal:
583 case Intrinsic::x86_tdpbusd_internal:
584 case Intrinsic::x86_tdpbuud_internal:
585 case Intrinsic::x86_tileloadd64_internal:
586 case Intrinsic::x86_tilestored64_internal:
587 case Intrinsic::x86_tilezero_internal:
588 case Intrinsic::x86_tdpbf16ps_internal:
598 for (
auto *Inst : WorkList) {
599 switch (Inst->getIntrinsicID()) {
600 case Intrinsic::x86_tdpbssd_internal:
601 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) ||
C;
603 case Intrinsic::x86_tdpbsud_internal:
604 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) ||
C;
606 case Intrinsic::x86_tdpbusd_internal:
607 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) ||
C;
609 case Intrinsic::x86_tdpbuud_internal:
610 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) ||
C;
612 case Intrinsic::x86_tdpbf16ps_internal:
613 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) ||
C;
615 case Intrinsic::x86_tileloadd64_internal:
616 C = lowerTileLoadStore<true>(Inst) ||
C;
618 case Intrinsic::x86_tilestored64_internal:
619 C = lowerTileLoadStore<false>(Inst) ||
C;
621 case Intrinsic::x86_tilezero_internal:
622 C = lowerTileZero(Inst) ||
C;
633bool shouldRunLowerAMXIntrinsics(
const Function &
F,
const TargetMachine *TM) {
635 TM->getOptLevel() == CodeGenOptLevel::None);
638bool runLowerAMXIntrinsics(Function &
F, DominatorTree *DT, LoopInfo *LI) {
639 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
641 X86LowerAMXIntrinsics LAT(
F, DTU, LI);
648 if (!shouldRunLowerAMXIntrinsics(
F, TM))
653 bool Changed = runLowerAMXIntrinsics(
F, &DT, &LI);
664class X86LowerAMXIntrinsicsLegacyPass :
public FunctionPass {
672 if (!shouldRunLowerAMXIntrinsics(
F, TM))
675 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
676 auto *DT = DTWP ? &DTWP->getDomTree() :
nullptr;
677 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
678 auto *LI = LIWP ? &LIWP->getLoopInfo() :
nullptr;
679 return runLowerAMXIntrinsics(
F, DT, LI);
681 StringRef getPassName()
const override {
return "Lower AMX intrinsics"; }
683 void getAnalysisUsage(AnalysisUsage &AU)
const override {
691static const char PassName[] =
"Lower AMX intrinsics";
692char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
700 return new X86LowerAMXIntrinsicsLegacyPass();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool runOnFunction(Function &F, bool PostInlining)
This header defines various interfaces for pass management in LLVM.
uint64_t IntrinsicInst * II
FunctionAnalysisManager FAM
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
const SmallVectorImpl< MachineOperand > & Cond
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
Target-Independent Code Generator Pass Configuration Options pass.
static cl::opt< bool > X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, cl::desc("X86: enable AMX scalarizition."))
static bool isV256I32Ty(Type *Ty)
static const char PassName[]
static const uint32_t IV[8]
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
const Function * getParent() const
Return the enclosing method, or null if none.
LLVM_ABI InstListType::const_iterator getFirstNonPHIIt() const
Returns an iterator to the first instruction in this block that is not a PHINode instruction.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
LLVM_ABI const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
InstListType::iterator iterator
Instruction iterators...
LLVM_ABI LLVMContext & getContext() const
Get the context in which this basic block lives.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
void setSuccessor(unsigned idx, BasicBlock *NewSucc)
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
static LLVM_ABI FixedVectorType * get(Type *ElementType, unsigned NumElts)
FunctionPass class - This class is used to implement most global optimizations.
void applyUpdatesPermissive(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
Common base class shared among various IRBuilders.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Analysis pass that exposes the LoopInfo for a function.
void addChildLoop(LoopT *NewChild)
Add the specified loop to be a child of this loop.
void addTopLevelLoop(LoopT *New)
This adds the specified loop to the collection of top-level loops.
LoopT * AllocateLoop(ArgsTy &&...Args)
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
Represents a single loop in the control flow graph.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
void push_back(const T &Elt)
StringRef - Represent a constant reference to a string, i.e.
Primary interface to the complete machine description for the target machine.
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
iterator_range< use_iterator > uses()
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM)
const ParentTy * getParent() const
Pass manager infrastructure for declaring and invalidating analyses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
@ BasicBlock
Various leaf nodes.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
CastOperator_match< OpTy, Instruction::BitCast > m_BitCast(const OpTy &Op)
Matches BitCast.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
initializer< Ty > init(const Ty &Val)
friend class Instruction
Iterator for Instructions in a `BasicBlock.
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
FunctionPass * createX86LowerAMXIntrinsicsLegacyPass()
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.