Index: llvm/lib/Transforms/Scalar/LICM.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LICM.cpp
+++ llvm/lib/Transforms/Scalar/LICM.cpp
@@ -522,6 +522,33 @@
   return Changed;
 }
 
+/// Return true if the metadata attached to \p I remains valid when \p I is
+/// hoisted into the preheader of \p CurLoop, i.e. when \p I is guaranteed to
+/// execute whenever the loop is entered.
+static bool canKeepMetadata(const Instruction &I, const DominatorTree *DT,
+                            const Loop *CurLoop,
+                            const LoopSafetyInfo *SafetyInfo) {
+  // The metadata is valid in the loop preheader if we are guaranteed to
+  // execute I if we entered the loop.
+  return isGuaranteedToExecute(I, DT, CurLoop, SafetyInfo);
+}
+
+/// Return true if \p LI is an !invariant.group load that may be hoisted out of
+/// \p CurLoop without dropping its metadata (see canKeepMetadata).
+static bool
+isUnconditionalInvariantGroupLoad(const LoadInst *LI, const DominatorTree *DT,
+                                  const Loop *CurLoop,
+                                  const LoopSafetyInfo *SafetyInfo) {
+  if (!LI->getMetadata(LLVMContext::MD_invariant_group))
+    return false;
+
+  // For now we only want to hoist invariant.group loads if we can keep
+  // the metadata. This is because we don't know yet if it's better to hoist it
+  // and lose the metadata, or to keep the metadata, counting on being able to
+  // merge this load with another one outside the loop.
+  return canKeepMetadata(*LI, DT, CurLoop, SafetyInfo);
+}
+
 // Return true if LI is invariant within scope of the loop. LI is invariant if
 // CurLoop is dominated by an invariant.start representing the same memory
 // location and size as the memory location LI loads from, and also the
@@ -597,6 +624,9 @@
     if (LI->isAtomic() && SinkingToLoopBody)
       return false; // Don't sink unordered atomic loads to loop body.
 
+    if (isUnconditionalInvariantGroupLoad(LI, DT, CurLoop, SafetyInfo))
+      return true;
+
     // This checks for an invariant.start dominating the load.
     if (isLoadInvariantInLoop(LI, DT, CurLoop))
       return true;
@@ -1041,15 +1071,13 @@
                                                          << ore::NV("Inst", &I);
   });
 
-  // Metadata can be dependent on conditions we are hoisting above.
-  // Conservatively strip all metadata on the instruction unless we were
-  // guaranteed to execute I if we entered the loop, in which case the metadata
-  // is valid in the loop preheader.
+  // Metadata can be dependent on conditions we are hoisting above. Strip it
+  // unless canKeepMetadata proves it is independent of any such conditions.
   if (I.hasMetadataOtherThanDebugLoc() &&
       // The check on hasMetadataOtherThanDebugLoc is to prevent us from burning
-      // time in isGuaranteedToExecute if we don't actually have anything to
+      // time in canKeepMetadata if we don't actually have anything to
       // drop.  It is a compile time optimization, not required for correctness.
-      !isGuaranteedToExecute(I, DT, CurLoop, SafetyInfo))
+      !canKeepMetadata(I, DT, CurLoop, SafetyInfo))
     I.dropUnknownNonDebugMetadata();
 
   // Move the new node to the Preheader, before its terminator.
Index: llvm/test/Transforms/LICM/hoist-invariant-group-load.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LICM/hoist-invariant-group-load.ll
@@ -0,0 +1,128 @@
+; RUN: opt -licm -disable-basicaa -S < %s | FileCheck %s
+
+%struct.A = type { i32 (...)** }
+
+; CHECK-LABEL: @hoist(
+define void @hoist(%struct.A* %arg, i1 %c) {
+entry:
+  br i1 %c, label %while.end, label %while.body.lr.ph
+
+; CHECK: while.body.lr.ph:
+while.body.lr.ph:                                 ; preds = %entry
+; CHECK: [[VTABLE:%.*]] = load void (%struct.A*)**, void (%struct.A*)*** [[B:%.*]], align 8, !invariant.group
+; CHECK-NEXT: [[TMP:%.*]] = load void (%struct.A*)*, void (%struct.A*)** [[VTABLE]], align 8, !invariant.load
+; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
+  %b = bitcast %struct.A* %arg to void (%struct.A*)***
+  br label %while.body
+
+while.body:                                       ; preds = %while.body, %while.body.lr.ph
+; CHECK: while.body:
+
+  %vtable = load void (%struct.A*)**, void (%struct.A*)*** %b, align 8, !invariant.group !1
+  %tmp = load void (%struct.A*)*, void (%struct.A*)** %vtable, align 8, !invariant.load !1
+  tail call void %tmp(%struct.A* %arg)
+  %call = tail call i32 @bar()
+  %tobool = icmp eq i32 %call, 0
+  br i1 %tobool, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:                               ; preds = %while.body
+  br label %while.end
+
+while.end:                                        ; preds = %while.end.loopexit, %entry
+  ret void
+}
+
+; CHECK-LABEL: @hoist2(
+define void @hoist2(i8** %arg) {
+entry:
+  %call1 = tail call i32 @bar()
+  %tobool2 = icmp eq i32 %call1, 0
+  br i1 %tobool2, label %while.end, label %while.body.lr.ph
+
+while.body.lr.ph:                                 ; preds = %entry
+; CHECK: while.body.lr.ph:
+; CHECK-NEXT: [[X:%.*]] = load i8*, i8** [[ARG:%.*]], align 8, !invariant.group
+; CHECK-NEXT: br label [[WHILE_BODY:%.*]]
+  br label %while.body
+
+; CHECK: while.body:
+while.body:                                       ; preds = %while.body, %while.body.lr.ph
+  %x = load i8*, i8** %arg, align 8, !invariant.group !1
+  call void @foo(i8* %x)
+  %call = tail call i32 @bar()
+  %tobool = icmp eq i32 %call, 0
+  br i1 %tobool, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:                               ; preds = %while.body
+  br label %while.end
+
+while.end:                                        ; preds = %while.end.loopexit, %entry
+  ret void
+}
+
+declare void @foo(i8*)
+
+declare i32 @bar()
+
+; CHECK-LABEL: @dontHoist(
+define void @dontHoist(%struct.A** %a) {
+
+entry:
+  %call4 = tail call i32 @bar()
+  %cmp5 = icmp sgt i32 %call4, 0
+  br i1 %cmp5, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+; CHECK: for.body:
+for.body:
+; CHECK: [[VTABLE:%.*]] = load void (%struct.A*)**, void (%struct.A*)*** {{.*}}, align 8, !dereferenceable !{{.*}}, !invariant.group
+; CHECK-NEXT: [[TMP2:%.*]] = load void (%struct.A*)*, void (%struct.A*)** [[VTABLE]], align 8, !invariant.load
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %arrayidx = getelementptr inbounds %struct.A*, %struct.A** %a, i64 %indvars.iv
+  %tmp = load %struct.A*, %struct.A** %arrayidx, align 8
+  %tmp1 = bitcast %struct.A* %tmp to void (%struct.A*)***
+  %vtable = load void (%struct.A*)**, void (%struct.A*)*** %tmp1, align 8, !dereferenceable !0, !invariant.group !1
+  %tmp2 = load void (%struct.A*)*, void (%struct.A*)** %vtable, align 8, !invariant.load !1
+  tail call void %tmp2(%struct.A* %tmp)
+  %indvars.iv.next = add nuw i64 %indvars.iv, 1
+  %call = tail call i32 @bar()
+  %tmp3 = sext i32 %call to i64
+  %cmp = icmp slt i64 %indvars.iv.next, %tmp3
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+; CHECK-LABEL: @donthoist2(
+define void @donthoist2(i8** dereferenceable(8) %arg, i1 %c) {
+entry:
+  br i1 %c, label %while.end, label %while.body.lr.ph
+
+while.body.lr.ph:                                 ; preds = %entry
+  br label %while.body
+
+; CHECK: while.body:
+while.body:                                       ; preds = %while.body, %while.body.lr.ph
+; CHECK: [[X:%.*]] = load i8*, i8** [[ARG:%.*]], align 8, !invariant.group
+  %call = tail call i32 @bar()
+  %x = load i8*, i8** %arg, align 8, !invariant.group !1
+  call void @foo(i8* %x)
+
+  %tobool = icmp eq i32 %call, 0
+  br i1 %tobool, label %while.end.loopexit, label %while.body
+
+while.end.loopexit:                               ; preds = %while.body
+  br label %while.end
+
+while.end:                                        ; preds = %while.end.loopexit, %entry
+  ret void
+}
+
+!0 = !{i64 8}
+!1 = !{}