diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -37,7 +37,6 @@ static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; } static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; } static StringRef getAliasScopesAttrName() { return "alias_scopes"; } - static StringRef getLoopAttrName() { return "llvm.loop"; } static StringRef getAccessGroupsAttrName() { return "access_groups"; } /// Names of llvm parameter attributes. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -778,7 +778,10 @@ // Terminators. def LLVM_BrOp : LLVM_TerminatorOp<"br", [DeclareOpInterfaceMethods, Pure]> { - let arguments = (ins Variadic:$destOperands); + let arguments = (ins + Variadic:$destOperands, + OptionalAttr:$loop_annotation + ); let successors = (successor AnySuccessor:$dest); let assemblyFormat = [{ $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict @@ -787,8 +790,12 @@ OpBuilder<(ins "Block *":$dest), [{ build($_builder, $_state, ValueRange(), dest); }]>, + OpBuilder<(ins "ValueRange":$operands, "Block *":$dest), [{ + build($_builder, $_state, operands, /*loop_annotation=*/{}, dest); + }]>, LLVM_TerminatorPassthroughOpBuilder ]; + let hasVerifier = 1; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, @@ -796,7 +803,8 @@ let arguments = (ins I1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights, + OptionalAttr:$loop_annotation); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let assemblyFormat = [{ $condition ( `weights` `(` $branch_weights^ `)` )?
`,` @@ -809,23 +817,20 @@ OpBuilder<(ins "Value":$condition, "Block *":$trueDest, "ValueRange":$trueOperands, "Block *":$falseDest, "ValueRange":$falseOperands, - CArg<"std::optional>", "{}">:$weights), - [{ - ElementsAttr weightsAttr; - if (weights) { - weightsAttr = - $_builder.getI32VectorAttr({static_cast(weights->first), - static_cast(weights->second)}); - } - build($_builder, $_state, condition, trueOperands, falseOperands, weightsAttr, - trueDest, falseDest); - }]>, + CArg<"std::optional>", "{}">:$weights)>, OpBuilder<(ins "Value":$condition, "Block *":$trueDest, "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{ build($_builder, $_state, condition, trueDest, ValueRange(), falseDest, falseOperands); + }]>, + OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands, + "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), + [{ + build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights, + /*loop_annotation=*/{}, trueDest, falseDest); }]>, LLVM_TerminatorPassthroughOpBuilder]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -264,11 +264,42 @@ // LLVM::BrOp //===----------------------------------------------------------------------===// +/// Verify that each parallel access group in `loopAttr` names an access_group op. +static LogicalResult verifyLoopAnnotationAttr(LoopAnnotationAttr loopAttr, + Operation *op) { + if (!loopAttr) + return success(); + // If the loop annotation lists parallel accesses, enforce the following + // structure, which the module translation can assume.
+ ArrayRef parallelAccesses = loopAttr.getParallelAccesses(); + if (parallelAccesses.empty()) + return success(); + for (SymbolRefAttr accessGroupRef : parallelAccesses) { + StringAttr metadataName = accessGroupRef.getRootReference(); + auto metadataOp = SymbolTable::lookupNearestSymbolFrom( + op->getParentOp(), metadataName); + if (!metadataOp) + return op->emitOpError() << "expected '" << accessGroupRef + << "' to reference a metadata op"; + StringAttr accessGroupName = accessGroupRef.getLeafReference(); + Operation *accessGroupOp = + SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); + if (!accessGroupOp) + return op->emitOpError() << "expected '" << accessGroupRef + << "' to reference an access_group op"; + } + return success(); +} + SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); return SuccessorOperands(getDestOperandsMutable()); } +LogicalResult BrOp::verify() { + return verifyLoopAnnotationAttr(getLoopAnnotationAttr(), *this); +} + //===----------------------------------------------------------------------===// // LLVM::CondBrOp //===----------------------------------------------------------------------===// @@ -279,6 +310,24 @@ : getFalseDestOperandsMutable()); } +LogicalResult CondBrOp::verify() { + return verifyLoopAnnotationAttr(getLoopAnnotationAttr(), *this); +} + +void CondBrOp::build(OpBuilder &builder, OperationState &result, + Value condition, Block *trueDest, ValueRange trueOperands, + Block *falseDest, ValueRange falseOperands, + std::optional> weights) { + ElementsAttr weightsAttr; + if (weights) + weightsAttr = + builder.getI32VectorAttr({static_cast(weights->first), + static_cast(weights->second)}); + + build(builder, result, condition, trueOperands, falseOperands, weightsAttr, + /*loop_annotation=*/{}, trueDest, falseDest); +} + //===----------------------------------------------------------------------===// // LLVM::SwitchOp
//===----------------------------------------------------------------------===// @@ -2977,32 +3026,6 @@ /// Verify LLVM dialect attributes. LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - // If the `llvm.loop` attribute is present, enforce the following structure, - // which the module translation can assume. - if (attr.getName() == LLVMDialect::getLoopAttrName()) { - auto loopAttr = attr.getValue().dyn_cast(); - if (!loopAttr) - return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() - << "' to be a loop annotation attribute"; - ArrayRef parallelAccesses = loopAttr.getParallelAccesses(); - if (parallelAccesses.empty()) - return success(); - for (SymbolRefAttr accessGroupRef : parallelAccesses) { - StringAttr metadataName = accessGroupRef.getRootReference(); - auto metadataOp = SymbolTable::lookupNearestSymbolFrom( - op->getParentOp(), metadataName); - if (!metadataOp) - return op->emitOpError() << "expected '" << accessGroupRef - << "' to reference a metadata op"; - StringAttr accessGroupName = accessGroupRef.getLeafReference(); - Operation *accessGroupOp = - SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); - if (!accessGroupOp) - return op->emitOpError() << "expected '" << accessGroupRef - << "' to reference an access_group op"; - } - } - // If the data layout attribute is present, it must use the LLVM data layout // syntax. Try parsing it and report errors in case of failure.
Users of this // attribute may assume it is well-formed and can pass it to the (asserting) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -203,8 +203,12 @@ if (!attr) return failure(); - op->setAttr(LLVMDialect::getLoopAttrName(), attr); - return success(); + return TypeSwitch(op) + .Case([&](auto branchOp) { + branchOp.setLoopAnnotationAttr(attr); + return success(); + }) + .Default([](auto) { return failure(); }); } namespace { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1244,8 +1244,11 @@ void ModuleTranslation::setLoopMetadata(Operation *op, llvm::Instruction *inst) { - auto attr = - op->getAttrOfType(LLVMDialect::getLoopAttrName()); + LoopAnnotationAttr attr = + TypeSwitch(op) + .Case( + [](auto branchOp) { return branchOp.getLoopAnnotationAttr(); }) + .Default([](Operation *) { return LoopAnnotationAttr(); }); if (!attr) return; llvm::MDNode *loopMD = loopAnnotationTranslation->translate(attr, op); diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -840,21 +840,10 @@ // ----- -module { - llvm.func @loopOptions() { - // expected-error@below {{expected 'llvm.loop' to be a loop annotation attribute}} - llvm.br ^bb4 {llvm.loop = "test"} - ^bb4: - llvm.return - } -} - -// ----- - module { llvm.func @loopOptions() { // expected-error@below {{expected '@func1' to reference a metadata op}} - llvm.br ^bb4 {llvm.loop = #llvm.loop_annotation} + llvm.br ^bb4 {loop_annotation = #llvm.loop_annotation} ^bb4: llvm.return } @@ -868,7 +857,7 @@ module { llvm.func @loopOptions() { //
expected-error@below {{expected '@metadata' to reference an access_group op}} - llvm.br ^bb4 {llvm.loop = #llvm.loop_annotation} + llvm.br ^bb4 {loop_annotation = #llvm.loop_annotation} ^bb4: llvm.return } diff --git a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll --- a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll @@ -32,7 +32,7 @@ ; CHECK-LABEL: @simple define void @simple(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -52,7 +52,7 @@ ; CHECK-LABEL: @vectorize define void @vectorize(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -78,7 +78,7 @@ ; CHECK-LABEL: @interleave define void @interleave(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -96,7 +96,7 @@ ; CHECK-LABEL: @unroll define void @unroll(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -122,7 +122,7 @@ ; CHECK-LABEL: @unroll_disable define void @unroll_disable(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -140,7 +140,7 @@ ; CHECK-LABEL: @unroll_and_jam define void @unroll_and_jam(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -166,7 +166,7 @@ ;
CHECK-LABEL: @licm define void @licm(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -185,7 +185,7 @@ ; CHECK-LABEL: @distribute define void @distribute(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -209,7 +209,7 @@ ; CHECK-LABEL: @pipeline define void @pipeline(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -227,7 +227,7 @@ ; CHECK-LABEL: @peeled define void @peeled(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -244,7 +244,7 @@ ; CHECK-LABEL: @unswitched define void @unswitched(i64 %n, ptr %A) { entry: -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -264,7 +264,7 @@ define void @parallel_accesses(ptr %arg) { entry: %0 = load i32, ptr %arg, !llvm.access.group !0 -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void @@ -287,7 +287,7 @@ entry: %0 = load i32, ptr %arg, !llvm.access.group !0 %1 = load i32, ptr %arg, !llvm.access.group !3 -; CHECK: llvm.br ^{{.*}} {llvm.loop = #[[$ANNOT_ATTR]]} +; CHECK: llvm.br ^{{.*}} {loop_annotation = #[[$ANNOT_ATTR]]} br label %end, !llvm.loop !1 end: ret void diff --git a/mlir/test/Target/LLVMIR/loop-metadata.mlir b/mlir/test/Target/LLVMIR/loop-metadata.mlir --- a/mlir/test/Target/LLVMIR/loop-metadata.mlir +++
b/mlir/test/Target/LLVMIR/loop-metadata.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @disableNonForced llvm.func @disableNonForced() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation} ^bb1: llvm.return } @@ -16,7 +16,7 @@ // CHECK-LABEL: @mustprogress llvm.func @mustprogress() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation} ^bb1: llvm.return } @@ -29,7 +29,7 @@ // CHECK-LABEL: @isvectorized llvm.func @isvectorized() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation} ^bb1: llvm.return } @@ -44,7 +44,7 @@ // CHECK-LABEL: @vectorizeOptions llvm.func @vectorizeOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation >} @@ -68,7 +68,7 @@ // CHECK-LABEL: @interleaveOptions llvm.func @interleaveOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -83,7 +83,7 @@ // CHECK-LABEL: @unrollOptions llvm.func @unrollOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation >} @@ -106,7 +106,7 @@ // CHECK-LABEL: @unrollOptions2 llvm.func @unrollOptions2() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -122,7 +122,7 @@ // CHECK-LABEL: @unrollAndJamOptions llvm.func @unrollAndJamOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation >} @@ -146,7 +146,7 @@ // CHECK-LABEL: @licmOptions llvm.func
@licmOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -159,7 +159,7 @@ // CHECK-LABEL: @licmOptions2 llvm.func @licmOptions2() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -174,7 +174,7 @@ // CHECK-LABEL: @distributeOptions llvm.func @distributeOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation >} @@ -196,7 +196,7 @@ // CHECK-LABEL: @pipelineOptions llvm.func @pipelineOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -210,7 +210,7 @@ // CHECK-LABEL: @peeledOptions llvm.func @peeledOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -223,7 +223,7 @@ // CHECK-LABEL: @unswitchOptions llvm.func @unswitchOptions() { // CHECK: br {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.br ^bb1 {llvm.loop = #llvm.loop_annotation>} + llvm.br ^bb1 {loop_annotation = #llvm.loop_annotation>} ^bb1: llvm.return } @@ -241,7 +241,7 @@ ^bb3(%1: i32): %2 = llvm.icmp "slt" %1, %arg1 : i32 // CHECK: br i1 {{.*}} !llvm.loop ![[LOOP_NODE:[0-9]+]] - llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = #llvm.loop_annotation< + llvm.cond_br %2, ^bb4, ^bb5 {loop_annotation = #llvm.loop_annotation< licm = , interleave = , unroll = , pipeline = , @@ -251,7 +251,7 @@ // CHECK: = load i32, ptr %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]] %5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr // CHECK: br label {{.*}} !llvm.loop
![[LOOP_NODE]] - llvm.br ^bb3(%3 : i32) {llvm.loop = #llvm.loop_annotation< + llvm.br ^bb3(%3 : i32) {loop_annotation = #llvm.loop_annotation< licm = , interleave = , unroll = , pipeline = ,