diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1979,6 +1979,11 @@
 // Op attached regions have no arguments
 def NoRegionArguments : NativeOpTrait<"NoRegionArguments">;
 
+// Op can only be used in a trivial SSACFG region, i.e. every enclosing region
+// up to the nearest isolated-from-above op has a single block and SSA
+// dominance (no block-based CFG).
+def UsedOnlyInTrivialSSACFG : NativeOpTrait<"UsedOnlyInTrivialSSACFG">;
+
 //===----------------------------------------------------------------------===//
 // OpInterface definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -277,6 +277,7 @@
 LogicalResult verifyNoRegionArguments(Operation *op);
 LogicalResult verifyElementwise(Operation *op);
 LogicalResult verifyIsIsolatedFromAbove(Operation *op);
+LogicalResult verifyUsedOnlyInTrivialSSACFG(Operation *op);
 } // namespace impl
 
 /// Helper class for implementing traits. Clients are not expected to interact
@@ -1411,6 +1412,17 @@
   }
 };
 
+/// This trait verifies that the operation is nested only in trivial SSACFG
+/// regions: every enclosing region, up to the nearest isolated-from-above
+/// ancestor, must have a single block and SSA dominance.
+template <typename ConcreteType>
+struct UsedOnlyInTrivialSSACFG
+    : public TraitBase<ConcreteType, UsedOnlyInTrivialSSACFG> {
+  static LogicalResult verifyTrait(Operation *op) {
+    return ::mlir::OpTrait::impl::verifyUsedOnlyInTrivialSSACFG(op);
+  }
+};
+
 /// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
 /// provide an easy way for scalar operations to conveniently generalize their
 /// behavior to vectors/tensors, and systematize conversion between these forms.
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1156,6 +1156,47 @@
   return success();
 }
 
+/// Verifies that every region enclosing `op`, up to and including the one
+/// owned by the nearest isolated-from-above ancestor, has exactly one block
+/// and SSA dominance.
+LogicalResult OpTrait::impl::verifyUsedOnlyInTrivialSSACFG(Operation *op) {
+  Region *parentRegion = op->getParentRegion();
+  while (parentRegion) {
+    Operation *parentOp = parentRegion->getParentOp();
+    // A detached region has no owning operation: there is nothing to report
+    // against and nowhere further to walk.
+    if (!parentOp)
+      break;
+
+    if (parentRegion->getBlocks().size() != 1) {
+      return op->emitOpError("parent region ")
+          .append(parentOp->getLoc())
+          .append(" of the operation is not a trivial SSACFG (it has multiple "
+                  "blocks)");
+    }
+
+    // Parents that do not implement RegionKindInterface are not checked for
+    // dominance.
+    if (auto iface = dyn_cast<RegionKindInterface>(parentOp)) {
+      if (!iface.hasSSADominance(parentRegion->getRegionNumber())) {
+        return op->emitOpError("parent region ")
+            .append(parentOp->getLoc())
+            .append(
+                " of the operation is not a trivial SSACFG (it doesn't have "
+                "SSA dominance property)");
+      }
+    }
+
+    // Stop at the first isolated-from-above ancestor.
+    if (parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
+
+    parentRegion = parentOp->getParentRegion();
+  }
+
+  return success();
+}
+
 bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
   return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
          op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -701,7 +701,7 @@
   if (useInfo.number < entries.size() && entries[useInfo.number].value) {
     Value result = entries[useInfo.number].value;
     // Check that the type matches the other uses.
-    if (result.getType() == type)
+    if (!type || result.getType() == type)
       return maybeRecordUse(result);
 
     emitError(useInfo.loc, "use of value '")
@@ -713,6 +713,15 @@
     return nullptr;
   }
 
+  // A use without a type cannot be resolved unless the value is already
+  // defined. `entries` may not have a slot for `useInfo.number` yet (it is
+  // only resized below), so it must not be indexed here.
+  if (!type) {
+    emitError(useInfo.loc, "forward reference of value '")
+        .append(useInfo.name, "' requires explicit type specification");
+    return nullptr;
+  }
+
   // Make sure we have enough slots for this.
   if (entries.size() <= useInfo.number)
     entries.resize(useInfo.number + 1);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1732,8 +1732,8 @@
 }
 
 // Test various mixings of operand type formatting.
-class FormatOperandBase<string suffix, string fmt>
-    : TEST_Op<"format_operand_" # suffix # "_op"> {
+class FormatOperandBase<string suffix, string fmt, list<OpTrait> traits = []>
+    : TEST_Op<"format_operand_" # suffix # "_op", traits> {
   let arguments = (ins I64:$buildable, AnyMemRef:$operand);
   let assemblyFormat = fmt;
 }
@@ -1753,6 +1753,14 @@
 def FormatOperandEOp : FormatOperandBase<"e", [{
   $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
 }]>;
+// Formats that omit the type of the non-buildable `$operand`; accepted only
+// because of the UsedOnlyInTrivialSSACFG trait.
+def FormatOperandTrivialCFGAOp : FormatOperandBase<"trivial_cfg_a", [{
+  operands attr-dict
+}], [UsedOnlyInTrivialSSACFG]>;
+def FormatOperandTrivialCFGBOp : FormatOperandBase<"trivial_cfg_b", [{
+  $buildable `,` $operand attr-dict
+}], [UsedOnlyInTrivialSSACFG]>;
 
 def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
   let successors = (successor VariadicSuccessor:$targets);
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -148,6 +148,17 @@
 // CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
 test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
 
+// Operand types are elided for ops with the UsedOnlyInTrivialSSACFG trait.
+func @inner_f(%arg0: i64, %arg1: memref<1xf64>) {
+  // CHECK: test.format_operand_trivial_cfg_a_op %arg0, %arg1
+  test.format_operand_trivial_cfg_a_op %arg0, %arg1
+
+  // CHECK: test.format_operand_trivial_cfg_b_op %arg0, %arg1
+  test.format_operand_trivial_cfg_b_op %arg0, %arg1
+
+  return
+}
+
 // CHECK: test.format_variadic_operand %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
 test.format_variadic_operand %i64, %i64, %i64 : i64, i64, i64
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -422,11 +422,21 @@
     Optional<StringRef> getVarTransformer() const {
       return variableTransformer;
     }
+    /// Returns true if the type of this operand is elided from the format and
+    /// the operand is parsed without one.
+    bool shouldSkipTypeDefinition() const { return skipTypeDefinition; }
     void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
       resolver = arg;
       variableTransformer = transformer;
       assert(getVariable() || getAttribute());
     }
+    /// Marks the type of operand `arg` as elided from the format. `arg` is
+    /// kept as the resolver so that its name can be used for the type storage.
+    void setSkipTypeDefinition(const NamedTypeConstraint *arg) {
+      assert(arg && "expected a valid operand");
+      resolver = arg;
+      skipTypeDefinition = true;
+    }
 
   private:
     /// If the type is resolved with a buildable type, this is the index into
@@ -438,6 +448,8 @@
     /// If the type is resolved based upon another operand or result, this is
     /// a transformer to apply to the variable when resolving.
     Optional<StringRef> variableTransformer;
+    /// Whether the operand type is elided from the format (trivial CFG case).
+    bool skipTypeDefinition = false;
   };
 
   OperationFormat(const Operator &op)
@@ -452,6 +464,9 @@
 
     hasSingleBlockTrait =
         hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
+
+    hasUsedOnlyInTrivialSSACFGTrait =
+        op.getTrait("::mlir::OpTrait::UsedOnlyInTrivialSSACFG");
   }
 
   /// Generate the operation parser from this format.
@@ -493,6 +508,9 @@
   /// A flag indicating if this operation has the SingleBlock trait.
   bool hasSingleBlockTrait;
 
+  /// A flag indicating if this operation has the UsedOnlyInTrivialSSACFG trait.
+  bool hasUsedOnlyInTrivialSSACFGTrait;
+
   /// A map of buildable types to indices.
   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
 
@@ -1190,6 +1208,19 @@
   for (auto &element : elements)
     genElementParserStorage(&*element, body);
 
+  // Declare storage holding a single default-constructed type for every
+  // operand whose type is elided from the format (this is only ever set for
+  // ops with the UsedOnlyInTrivialSSACFG trait).
+  for (TypeResolution &resolver : operandTypes) {
+    if (!resolver.shouldSkipTypeDefinition())
+      continue;
+    const NamedTypeConstraint *var = resolver.getVariable();
+    body << llvm::formatv("  ::mlir::Type {0}RawTypes[1];\n", var->name)
+         << llvm::formatv(
+                "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
+                var->name);
+  }
+
   // A format context used when parsing attributes with buildable types.
   FmtContext attrTypeCtx;
   attrTypeCtx.withBuilder("parser.getBuilder()");
@@ -2675,7 +2706,13 @@
     // Similarly to results, allow a custom builder for resolving the type if
     // we aren't using the 'operands' directive.
     Optional<StringRef> builder = operand.constraint.getBuilderCall();
-    if (!builder || (fmt.allOperands && operand.isVariableLength())) {
+    // Ops with the UsedOnlyInTrivialSSACFG trait may elide the type of a
+    // non-buildable operand. Only a single type slot is generated for such an
+    // operand, so this is limited to operands of fixed length.
+    bool canElideType = fmt.hasUsedOnlyInTrivialSSACFGTrait && !builder &&
+                        !operand.isVariableLength();
+    if (!canElideType &&
+        (!builder || (fmt.allOperands && operand.isVariableLength()))) {
       return emitErrorAndNote(
           loc,
           "type of operand #" + Twine(i) + ", named '" + operand.name +
@@ -2684,8 +2721,12 @@
           "'type($" + operand.name + ")' directive to the " +
               "custom assembly format");
     }
-    auto it = buildableTypes.insert({*builder, buildableTypes.size()});
-    fmt.operandTypes[i].setBuilderIdx(it.first->second);
+    if (canElideType) {
+      fmt.operandTypes[i].setSkipTypeDefinition(&operand);
+    } else {
+      auto it = buildableTypes.insert({*builder, buildableTypes.size()});
+      fmt.operandTypes[i].setBuilderIdx(it.first->second);
+    }
   }
   return ::mlir::success();
 }