diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -95,8 +95,9 @@ OptionalParameter<"IntegerAttr">:$count, OptionalParameter<"BoolAttr">:$runtimeDisable, OptionalParameter<"BoolAttr">:$full, - OptionalParameter<"LoopAnnotationAttr">:$followup, - OptionalParameter<"LoopAnnotationAttr">:$followupRemainder + OptionalParameter<"LoopAnnotationAttr">:$followupUnrolled, + OptionalParameter<"LoopAnnotationAttr">:$followupRemainder, + OptionalParameter<"LoopAnnotationAttr">:$followupAll ); let assemblyFormat = "`<` struct(params) `>`"; @@ -186,6 +187,7 @@ OptionalParameter<"LoopDistributeAttr">:$distribute, OptionalParameter<"LoopPipelineAttr">:$pipeline, OptionalParameter<"BoolAttr">:$mustProgress, + OptionalParameter<"BoolAttr">:$isVectorized, OptionalArrayRefParameter<"SymbolRefAttr">:$parallelAccesses ); diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationImporter.cpp @@ -33,6 +33,7 @@ /// specified name, or failure, if the node is ill-formatted. FailureOr lookupUnitNode(StringRef name); FailureOr lookupBoolNode(StringRef name, bool negated = false); + FailureOr lookupIntNodeAsBoolAttr(StringRef name); FailureOr lookupIntNode(StringRef name); FailureOr lookupMDNode(StringRef name); FailureOr> lookupMDNodes(StringRef name); @@ -155,6 +156,27 @@ return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated); } +FailureOr +LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) { + const llvm::MDNode *property = lookupAndEraseProperty(name); + if (!property) + return BoolAttr(nullptr); + + auto emitNodeWarning = [&]() { + return emitWarning(loc) + << "expected metadata node " << name << " to hold an integer value"; + }; + + if (property->getNumOperands() != 2) + return emitNodeWarning(); + llvm::ConstantInt *val = + llvm::mdconst::dyn_extract(property->getOperand(1)); + if (!val || val->getBitWidth() != 32) + return emitNodeWarning(); + + return BoolAttr::get(ctx, val->getValue().getLimitedValue(1)); +} + FailureOr LoopMetadataConversion::lookupIntNode(StringRef name) { const llvm::MDNode *property = lookupAndEraseProperty(name); if (!property) @@ -287,13 +309,16 @@ FailureOr runtimeDisable = lookupUnitNode("llvm.loop.unroll.runtime.disable"); FailureOr full = lookupUnitNode("llvm.loop.unroll.full"); - FailureOr followup = - lookupFollowupNode("llvm.loop.unroll.followup"); + FailureOr followupUnrolled = + lookupFollowupNode("llvm.loop.unroll.followup_unrolled"); FailureOr followupRemainder = lookupFollowupNode("llvm.loop.unroll.followup_remainder"); + FailureOr followupAll = + lookupFollowupNode("llvm.loop.unroll.followup_all"); return createIfNonNull(ctx, disable, count, runtimeDisable, - full, followup, followupRemainder); + full, followupUnrolled, + followupRemainder, followupAll); } FailureOr @@ -379,6 +404,8 @@ FailureOr distributeAttr = convertDistributeAttr(); FailureOr pipelineAttr = convertPipelineAttr(); FailureOr mustProgress = lookupUnitNode("llvm.loop.mustprogress"); + FailureOr isVectorized = + lookupIntNodeAsBoolAttr("llvm.loop.isvectorized"); FailureOr> parallelAccesses = convertParallelAccesses(); @@ -392,7 +419,7 @@ return createIfNonNull( ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr, unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, mustProgress, - parallelAccesses); + isVectorized, parallelAccesses); } LoopAnnotationAttr diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -29,6 +29,7 @@ /// Conversion functions for different payload attribute kinds. void addUnitNode(StringRef name); void addUnitNode(StringRef name, BoolAttr attr); + void addI32NodeWithVal(StringRef name, uint32_t val); void convertBoolNode(StringRef name, BoolAttr attr, bool negated = false); void convertI32Node(StringRef name, IntegerAttr attr); void convertFollowupNode(StringRef name, LoopAnnotationAttr attr); @@ -61,6 +62,14 @@ addUnitNode(name); } +void LoopAnnotationConversion::addI32NodeWithVal(StringRef name, uint32_t val) { + llvm::Constant *cstValue = llvm::ConstantInt::get( + llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false); + metadataNodes.push_back( + llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), + llvm::ConstantAsMetadata::get(cstValue)})); +} + void LoopAnnotationConversion::convertBoolNode(StringRef name, BoolAttr attr, bool negated) { if (!attr) @@ -76,12 +85,7 @@ IntegerAttr attr) { if (!attr) return; - uint32_t val = attr.getInt(); - llvm::Constant *cstValue = llvm::ConstantInt::get( - llvm::IntegerType::get(ctx, /*NumBits=*/32), val, /*isSigned=*/false); - metadataNodes.push_back( - llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), - llvm::ConstantAsMetadata::get(cstValue)})); + addI32NodeWithVal(name, attr.getInt()); } void LoopAnnotationConversion::convertFollowupNode(StringRef name, @@ -122,9 +126,12 @@ convertBoolNode("llvm.loop.unroll.runtime.disable", options.getRuntimeDisable()); addUnitNode("llvm.loop.unroll.full", options.getFull()); - convertFollowupNode("llvm.loop.unroll.followup", options.getFollowup()); + convertFollowupNode("llvm.loop.unroll.followup_unrolled", + options.getFollowupUnrolled()); convertFollowupNode("llvm.loop.unroll.followup_remainder", options.getFollowupRemainder()); + convertFollowupNode("llvm.loop.unroll.followup_all", + options.getFollowupAll()); } void LoopAnnotationConversion::convertLoopOptions( @@ -177,6 +184,9 @@ addUnitNode("llvm.loop.disable_nonforced", attr.getDisableNonforced()); addUnitNode("llvm.loop.mustprogress", attr.getMustProgress()); + // "isvectorized" is encoded as an i32 value. + if (BoolAttr isVectorized = attr.getIsVectorized()) + addI32NodeWithVal("llvm.loop.isvectorized", isVectorized.getValue()); if (auto options = attr.getVectorize()) convertLoopOptions(options); diff --git a/mlir/test/Dialect/LLVMIR/loop-metadata.mlir b/mlir/test/Dialect/LLVMIR/loop-metadata.mlir --- a/mlir/test/Dialect/LLVMIR/loop-metadata.mlir +++ b/mlir/test/Dialect/LLVMIR/loop-metadata.mlir @@ -12,10 +12,10 @@ // CHECK-DAG: #[[INTERLEAVE:.*]] = #llvm.loop_interleave #interleave = #llvm.loop_interleave -// CHECK-DAG: #[[UNROLL:.*]] = #llvm.loop_unroll +// CHECK-DAG: #[[UNROLL:.*]] = #llvm.loop_unroll #unroll = #llvm.loop_unroll< disable = true, count = 32 : i32, runtimeDisable = true, full = false, - followup = #followup, followupRemainder = #followup + followupUnrolled = #followup, followupRemainder = #followup, followupAll = #followup > // CHECK-DAG: #[[UNROLL_AND_JAM:.*]] = #llvm.loop_unroll_and_jam @@ -44,6 +44,7 @@ // CHECK-DAG: licm = #[[LICM]] // CHECK-DAG: distribute = #[[DISTRIBUTE]] // CHECK-DAG: pipeline = #[[PIPELINE]] +// CHECK-DAG: isVectorized = false // CHECK-DAG: parallelAccesses = @metadata::@group1, @metadata::@group2> #loopMD = #llvm.loop_annotation // CHECK: llvm.func @loop_annotation diff --git a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll --- a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll @@ -28,7 +28,7 @@ ; // ----- -; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation +; CHECK: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation ; CHECK-LABEL: @simple define void @simple(i64 %n, ptr %A) { @@ -39,9 +39,10 @@ ret void } -!1 = distinct !{!1, !2, !3} +!1 = distinct !{!1, !2, !3, !4} !2 = !{!"llvm.loop.disable_nonforced"} !3 = !{!"llvm.loop.mustprogress"} +!4 = !{!"llvm.loop.isvectorized", i32 1} ; // ----- @@ -90,7 +91,7 @@ ; // ----- ; CHECK-DAG: #[[FOLLOWUP:.*]] = #llvm.loop_annotation -; CHECK-DAG: #[[UNROLL_ATTR:.*]] = #llvm.loop_unroll +; CHECK-DAG: #[[UNROLL_ATTR:.*]] = #llvm.loop_unroll ; CHECK-DAG: #[[$ANNOT_ATTR:.*]] = #llvm.loop_annotation ; CHECK-LABEL: @unroll @@ -102,16 +103,17 @@ ret void } -!1 = distinct !{!1, !2, !3, !4, !5, !6, !7} +!1 = distinct !{!1, !2, !3, !4, !5, !6, !7, !8} !2 = !{!"llvm.loop.unroll.enable"} !3 = !{!"llvm.loop.unroll.count", i32 16} !4 = !{!"llvm.loop.unroll.runtime.disable"} !5 = !{!"llvm.loop.unroll.full"} -!6 = !{!"llvm.loop.unroll.followup", !8} -!7 = !{!"llvm.loop.unroll.followup_remainder", !8} +!6 = !{!"llvm.loop.unroll.followup_unrolled", !9} +!7 = !{!"llvm.loop.unroll.followup_remainder", !9} +!8 = !{!"llvm.loop.unroll.followup_all", !9} -!8 = distinct !{!8, !9} -!9 = !{!"llvm.loop.disable_nonforced"} +!9 = distinct !{!9, !10} +!10 = !{!"llvm.loop.disable_nonforced"} ; // ----- diff --git a/mlir/test/Target/LLVMIR/loop-metadata.mlir b/mlir/test/Target/LLVMIR/loop-metadata.mlir --- a/mlir/test/Target/LLVMIR/loop-metadata.mlir +++ b/mlir/test/Target/LLVMIR/loop-metadata.mlir @@ -26,6 +26,18 @@ // ----- +// CHECK-LABEL: @isvectorized +llvm.func @isvectorized() { + // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] + llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation} +^bb1: + llvm.return +} + +// CHECK: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}} +// CHECK-DAG: ![[VEC_NODE0:[0-9]+]] = !{!"llvm.loop.isvectorized", i32 1} + +// ----- #followup = #llvm.loop_annotation @@ -73,7 +85,7 @@ // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation + followupUnrolled = #followup, followupRemainder = #followup, followupAll = #followup> >} ^bb1: llvm.return @@ -81,12 +93,13 @@ // CHECK-DAG: ![[NON_FORCED:[0-9]+]] = !{!"llvm.loop.disable_nonforced"} // CHECK-DAG: ![[FOLLOWUP:[0-9]+]] = distinct !{![[FOLLOWUP]], ![[NON_FORCED]]} -// CHECK-DAG: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}} +// CHECK-DAG: ![[LOOP_NODE]] = distinct !{![[LOOP_NODE]], !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}} // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.disable"} // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.count", i32 64} // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.runtime.disable", i1 false} -// CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup", ![[FOLLOWUP]]} +// CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup_unrolled", ![[FOLLOWUP]]} // CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup_remainder", ![[FOLLOWUP]]} +// CHECK-DAG: !{{[0-9]+}} = !{!"llvm.loop.unroll.followup_all", ![[FOLLOWUP]]} // -----