diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -36,6 +36,11 @@ mlir_tablegen(TestPatterns.inc -gen-rewriters) add_public_tablegen_target(MLIRTestOpsIncGen) +set(LLVM_TARGET_DEFINITIONS TestOpsSyntax.td) +mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls) +mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTestOpsSyntaxIncGen) + # Exclude tests from libMLIR.so add_mlir_library(MLIRTestDialect TestAttributes.cpp @@ -44,6 +49,8 @@ TestPatterns.cpp TestTraits.cpp TestTypes.cpp + TestOpsSyntax.cpp + TestDialectInterfaces.cpp EXCLUDE_FROM_LIBMLIR @@ -53,6 +60,7 @@ MLIRTestInterfaceIncGen MLIRTestTypeDefIncGen MLIRTestOpsIncGen + MLIRTestOpsSyntaxIncGen LINK_LIBS PUBLIC MLIRControlFlowInterfaces diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -28,7 +28,6 @@ #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" @@ -120,357 +119,6 @@ registry.insert(); } -//===----------------------------------------------------------------------===// -// TestDialect version utilities -//===----------------------------------------------------------------------===// - -struct TestDialectVersion : public DialectVersion { - uint32_t major = 2; - uint32_t minor = 0; -}; - -//===----------------------------------------------------------------------===// -// TestDialect Interfaces -//===----------------------------------------------------------------------===// - -namespace { - -/// Testing the correctness of some traits. 
-static_assert( - llvm::is_detected::value, - "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); -static_assert(OpTrait::hasSingleBlockImplicitTerminator< - SingleBlockImplicitTerminatorOp>::value, - "hasSingleBlockImplicitTerminator does not match " - "SingleBlockImplicitTerminatorOp"); - -struct TestResourceBlobManagerInterface - : public ResourceBlobManagerDialectInterfaceBase< - TestDialectResourceBlobHandle> { - using ResourceBlobManagerDialectInterfaceBase< - TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase; -}; - -namespace { -enum test_encoding { k_attr_params = 0 }; -} - -// Test support for interacting with the Bytecode reader/writer. -struct TestBytecodeDialectInterface : public BytecodeDialectInterface { - using BytecodeDialectInterface::BytecodeDialectInterface; - TestBytecodeDialectInterface(Dialect *dialect) - : BytecodeDialectInterface(dialect) {} - - LogicalResult writeAttribute(Attribute attr, - DialectBytecodeWriter &writer) const final { - if (auto concreteAttr = llvm::dyn_cast(attr)) { - writer.writeVarInt(test_encoding::k_attr_params); - writer.writeVarInt(concreteAttr.getV0()); - writer.writeVarInt(concreteAttr.getV1()); - return success(); - } - return failure(); - } - - Attribute readAttribute(DialectBytecodeReader &reader, - const DialectVersion &version_) const final { - const auto &version = static_cast(version_); - if (version.major < 2) - return readAttrOldEncoding(reader); - if (version.major == 2 && version.minor == 0) - return readAttrNewEncoding(reader); - // Forbid reading future versions by returning nullptr. - return Attribute(); - } - - // Emit a specific version of the dialect. 
- void writeVersion(DialectBytecodeWriter &writer) const final { - auto version = TestDialectVersion(); - writer.writeVarInt(version.major); // major - writer.writeVarInt(version.minor); // minor - } - - std::unique_ptr - readVersion(DialectBytecodeReader &reader) const final { - uint64_t major, minor; - if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor))) - return nullptr; - auto version = std::make_unique(); - version->major = major; - version->minor = minor; - return version; - } - - LogicalResult upgradeFromVersion(Operation *topLevelOp, - const DialectVersion &version_) const final { - const auto &version = static_cast(version_); - if ((version.major == 2) && (version.minor == 0)) - return success(); - if (version.major > 2 || (version.major == 2 && version.minor > 0)) { - return topLevelOp->emitError() - << "current test dialect version is 2.0, can't parse version: " - << version.major << "." << version.minor; - } - // Prior version 2.0, the old op supported only a single attribute called - // "dimensions". We can perform the upgrade. - topLevelOp->walk([](TestVersionedOpA op) { - if (auto dims = op->getAttr("dimensions")) { - op->removeAttr("dimensions"); - op->setAttr("dims", dims); - } - op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); - }); - return success(); - } - -private: - Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const { - uint64_t encoding; - if (failed(reader.readVarInt(encoding)) || - encoding != test_encoding::k_attr_params) - return Attribute(); - // The new encoding has v0 first, v1 second. 
- uint64_t v0, v1; - if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) - return Attribute(); - return TestAttrParamsAttr::get(getContext(), static_cast(v0), - static_cast(v1)); - } - - Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const { - uint64_t encoding; - if (failed(reader.readVarInt(encoding)) || - encoding != test_encoding::k_attr_params) - return Attribute(); - // The old encoding has v1 first, v0 second. - uint64_t v0, v1; - if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0))) - return Attribute(); - return TestAttrParamsAttr::get(getContext(), static_cast(v0), - static_cast(v1)); - } -}; - -// Test support for interacting with the AsmPrinter. -struct TestOpAsmInterface : public OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; - TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr) - : OpAsmDialectInterface(dialect), blobManager(mgr) {} - - //===------------------------------------------------------------------===// - // Aliases - //===------------------------------------------------------------------===// - - AliasResult getAlias(Attribute attr, raw_ostream &os) const final { - StringAttr strAttr = dyn_cast(attr); - if (!strAttr) - return AliasResult::NoAlias; - - // Check the contents of the string attribute to see what the test alias - // should be named. 
- std::optional aliasName = - StringSwitch>(strAttr.getValue()) - .Case("alias_test:dot_in_name", StringRef("test.alias")) - .Case("alias_test:trailing_digit", StringRef("test_alias0")) - .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) - .Case("alias_test:sanitize_conflict_a", - StringRef("test_alias_conflict0")) - .Case("alias_test:sanitize_conflict_b", - StringRef("test_alias_conflict0_")) - .Case("alias_test:tensor_encoding", StringRef("test_encoding")) - .Default(std::nullopt); - if (!aliasName) - return AliasResult::NoAlias; - - os << *aliasName; - return AliasResult::FinalAlias; - } - - AliasResult getAlias(Type type, raw_ostream &os) const final { - if (auto tupleType = dyn_cast(type)) { - if (tupleType.size() > 0 && - llvm::all_of(tupleType.getTypes(), [](Type elemType) { - return isa(elemType); - })) { - os << "test_tuple"; - return AliasResult::FinalAlias; - } - } - if (auto intType = dyn_cast(type)) { - if (intType.getSignedness() == - TestIntegerType::SignednessSemantics::Unsigned && - intType.getWidth() == 8) { - os << "test_ui8"; - return AliasResult::FinalAlias; - } - } - if (auto recType = dyn_cast(type)) { - if (recType.getName() == "type_to_alias") { - // We only make alias for a specific recursive type. - os << "testrec"; - return AliasResult::FinalAlias; - } - } - return AliasResult::NoAlias; - } - - //===------------------------------------------------------------------===// - // Resources - //===------------------------------------------------------------------===// - - std::string - getResourceKey(const AsmDialectResourceHandle &handle) const override { - return cast(handle).getKey().str(); - } - - FailureOr - declareResource(StringRef key) const final { - return blobManager.insert(key); - } - - LogicalResult parseResource(AsmParsedResourceEntry &entry) const final { - FailureOr blob = entry.parseAsBlob(); - if (failed(blob)) - return failure(); - - // Update the blob for this entry. 
- blobManager.update(entry.getKey(), std::move(*blob)); - return success(); - } - - void - buildResources(Operation *op, - const SetVector &referencedResources, - AsmResourceBuilder &provider) const final { - blobManager.buildResources(provider, referencedResources.getArrayRef()); - } - -private: - /// The blob manager for the dialect. - TestResourceBlobManagerInterface &blobManager; -}; - -struct TestDialectFoldInterface : public DialectFoldInterface { - using DialectFoldInterface::DialectFoldInterface; - - /// Registered hook to check if the given region, which is attached to an - /// operation that is *not* isolated from above, should be used when - /// materializing constants. - bool shouldMaterializeInto(Region *region) const final { - // If this is a one region operation, then insert into it. - return isa(region->getParentOp()); - } -}; - -/// This class defines the interface for handling inlining with standard -/// operations. -struct TestInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - // Don't allow inlining calls that are marked `noinline`. - return !call->hasAttr("noinline"); - } - bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { - // Inlining into test dialect regions is legal. - return true; - } - bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { - return true; - } - - bool shouldAnalyzeRecursively(Operation *op) const final { - // Analyze recursively if this is not a functional region operation, it - // froms a separate functional scope. 
- return !isa(op); - } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { - // Only handle "test.return" here. - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } - - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - // Only allow conversion for i16/i32 types. 
- if (!(resultType.isSignlessInteger(16) || - resultType.isSignlessInteger(32)) || - !(input.getType().isSignlessInteger(16) || - input.getType().isSignlessInteger(32))) - return nullptr; - return builder.create(conversionLoc, resultType, input); - } - - Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, - Value argument, - DictionaryAttr argumentAttrs) const final { - if (!argumentAttrs.contains("test.handle_argument")) - return argument; - return builder.create(call->getLoc(), argument.getType(), - argument); - } - - Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, - Value result, DictionaryAttr resultAttrs) const final { - if (!resultAttrs.contains("test.handle_result")) - return result; - return builder.create(call->getLoc(), result.getType(), - result); - } - - void processInlinedCallBlocks( - Operation *call, - iterator_range inlinedBlocks) const final { - if (!isa(call)) - return; - - // Set attributed on all ops in the inlined blocks. 
- for (Block &block : inlinedBlocks) { - block.walk([&](Operation *op) { - op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); - }); - } - } -}; - -struct TestReductionPatternInterface : public DialectReductionPatternInterface { -public: - TestReductionPatternInterface(Dialect *dialect) - : DialectReductionPatternInterface(dialect) {} - - void populateReductionPatterns(RewritePatternSet &patterns) const final { - populateTestReductionPatterns(patterns); - } -}; - -} // namespace - //===----------------------------------------------------------------------===// // Dynamic operations //===----------------------------------------------------------------------===// @@ -557,16 +205,12 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + registerOpsSyntax(); addOperations(); registerDynamicOp(getDynamicGenericOp(this)); registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); registerDynamicOp(getDynamicCustomParserPrinterOp(this)); - - auto &blobInterface = addInterface(); - addInterface(blobInterface); - - addInterfaces(); + registerInterfaces(); allowUnknownOperations(); // Instantiate our fallback op interface that we'll use on specific @@ -583,15 +227,6 @@ return builder.create(loc, type, value); } -::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( - ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, - ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, - OpaqueProperties properties, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); - return ::mlir::success(); -} - void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, OperationName opName) { if (opName.getIdentifier() == "test.unregistered_side_effect_op" && @@ -785,224 +420,6 @@ results.add(context); } -//===----------------------------------------------------------------------===// -// Test Format* operations 
-//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// Parsing - -static ParseResult parseCustomOptionalOperand( - OpAsmParser &parser, - std::optional &optOperand) { - if (succeeded(parser.parseOptionalLParen())) { - optOperand.emplace(); - if (parser.parseOperand(*optOperand) || parser.parseRParen()) - return failure(); - } - return success(); -} - -static ParseResult parseCustomDirectiveOperands( - OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, - std::optional &optOperand, - SmallVectorImpl &varOperands) { - if (parser.parseOperand(operand)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - optOperand.emplace(); - if (parser.parseOperand(*optOperand)) - return failure(); - } - if (parser.parseArrow() || parser.parseLParen() || - parser.parseOperandList(varOperands) || parser.parseRParen()) - return failure(); - return success(); -} -static ParseResult -parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, - Type &optOperandType, - SmallVectorImpl &varOperandTypes) { - if (parser.parseColon()) - return failure(); - - if (parser.parseType(operandType)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseType(optOperandType)) - return failure(); - } - if (parser.parseArrow() || parser.parseLParen() || - parser.parseTypeList(varOperandTypes) || parser.parseRParen()) - return failure(); - return success(); -} -static ParseResult -parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, - Type optOperandType, - const SmallVectorImpl &varOperandTypes) { - if (parser.parseKeyword("type_refs_capture")) - return failure(); - - Type operandType2, optOperandType2; - SmallVector varOperandTypes2; - if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, - varOperandTypes2)) - return failure(); - - if (operandType != operandType2 || optOperandType != 
optOperandType2 || - varOperandTypes != varOperandTypes2) - return failure(); - - return success(); -} -static ParseResult parseCustomDirectiveOperandsAndTypes( - OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, - std::optional &optOperand, - SmallVectorImpl &varOperands, - Type &operandType, Type &optOperandType, - SmallVectorImpl &varOperandTypes) { - if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || - parseCustomDirectiveResults(parser, operandType, optOperandType, - varOperandTypes)) - return failure(); - return success(); -} -static ParseResult parseCustomDirectiveRegions( - OpAsmParser &parser, Region ®ion, - SmallVectorImpl> &varRegions) { - if (parser.parseRegion(region)) - return failure(); - if (failed(parser.parseOptionalComma())) - return success(); - std::unique_ptr varRegion = std::make_unique(); - if (parser.parseRegion(*varRegion)) - return failure(); - varRegions.emplace_back(std::move(varRegion)); - return success(); -} -static ParseResult -parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, - SmallVectorImpl &varSuccessors) { - if (parser.parseSuccessor(successor)) - return failure(); - if (failed(parser.parseOptionalComma())) - return success(); - Block *varSuccessor; - if (parser.parseSuccessor(varSuccessor)) - return failure(); - varSuccessors.append(2, varSuccessor); - return success(); -} -static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, - IntegerAttr &attr, - IntegerAttr &optAttr) { - if (parser.parseAttribute(attr)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseAttribute(optAttr)) - return failure(); - } - return success(); -} -static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser, - mlir::StringAttr &attr) { - return parser.parseAttribute(attr); -} -static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, - NamedAttrList &attrs) { - return parser.parseOptionalAttrDict(attrs); -} -static 
ParseResult parseCustomDirectiveOptionalOperandRef( - OpAsmParser &parser, - std::optional &optOperand) { - int64_t operandCount = 0; - if (parser.parseInteger(operandCount)) - return failure(); - bool expectedOptionalOperand = operandCount == 0; - return success(expectedOptionalOperand != optOperand.has_value()); -} - -//===----------------------------------------------------------------------===// -// Printing - -static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, - Value optOperand) { - if (optOperand) - printer << "(" << optOperand << ") "; -} - -static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, - Value operand, Value optOperand, - OperandRange varOperands) { - printer << operand; - if (optOperand) - printer << ", " << optOperand; - printer << " -> (" << varOperands << ")"; -} -static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, - Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printer << " : " << operandType; - if (optOperandType) - printer << ", " << optOperandType; - printer << " -> (" << varOperandTypes << ")"; -} -static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, - Operation *op, Type operandType, - Type optOperandType, - TypeRange varOperandTypes) { - printer << " type_refs_capture "; - printCustomDirectiveResults(printer, op, operandType, optOperandType, - varOperandTypes); -} -static void printCustomDirectiveOperandsAndTypes( - OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, - OperandRange varOperands, Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); - printCustomDirectiveResults(printer, op, operandType, optOperandType, - varOperandTypes); -} -static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, - Region ®ion, - MutableArrayRef varRegions) { - printer.printRegion(region); - if 
(!varRegions.empty()) { - printer << ", "; - for (Region ®ion : varRegions) - printer.printRegion(region); - } -} -static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, - Block *successor, - SuccessorRange varSuccessors) { - printer << successor; - if (!varSuccessors.empty()) - printer << ", " << varSuccessors.front(); -} -static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, - Attribute attribute, - Attribute optAttribute) { - printer << attribute; - if (optAttribute) - printer << ", " << optAttribute; -} -static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op, - Attribute attribute) { - printer << attribute; -} -static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs) { - printer.printOptionalAttrDict(attrs.getValue()); -} - -static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, - Operation *op, - Value optOperand) { - printer << (optOperand ? "1" : "0"); -} - //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// @@ -1060,249 +477,6 @@ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } -//===----------------------------------------------------------------------===// -// Test parser. 
-//===----------------------------------------------------------------------===// - -ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parser.parseOptionalColon()) - return success(); - uint64_t numResults; - if (parser.parseInteger(numResults)) - return failure(); - - IndexType type = parser.getBuilder().getIndexType(); - for (unsigned i = 0; i < numResults; ++i) - result.addTypes(type); - return success(); -} - -void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { - if (unsigned numResults = getNumResults()) - p << " : " << numResults; -} - -ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, - OperationState &result) { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return failure(); - result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); - return success(); -} - -void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } - -ParseResult ParseB64BytesOp::parse(OpAsmParser &parser, - OperationState &result) { - std::vector bytes; - if (parser.parseBase64Bytes(&bytes)) - return failure(); - result.addAttribute("b64", parser.getBuilder().getStringAttr( - StringRef(&bytes.front(), bytes.size()))); - return success(); -} - -void ParseB64BytesOp::print(OpAsmPrinter &p) { - p << " \"" << llvm::encodeBase64(getB64()) << "\""; -} - -//===----------------------------------------------------------------------===// -// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 
- -ParseResult WrappingRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parser.parseKeyword("wraps")) - return failure(); - - // Parse the wrapped op in a region - Region &body = *result.addRegion(); - body.push_back(new Block); - Block &block = body.back(); - Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); - if (!wrappedOp) - return failure(); - - // Create a return terminator in the inner region, pass as operand to the - // terminator the returned values from the wrapped operation. - SmallVector returnOperands(wrappedOp->getResults()); - OpBuilder builder(parser.getContext()); - builder.setInsertionPointToEnd(&block); - builder.create(wrappedOp->getLoc(), returnOperands); - - // Get the results type for the wrapping op from the terminator operands. - Operation &returnOp = body.back().back(); - result.types.append(returnOp.operand_type_begin(), - returnOp.operand_type_end()); - - // Use the location of the wrapped op for the "test.wrapping_region" op. - result.location = wrappedOp->getLoc(); - - return success(); -} - -void WrappingRegionOp::print(OpAsmPrinter &p) { - p << " wraps "; - p.printGenericOp(&getRegion().front().front()); -} - -//===----------------------------------------------------------------------===// -// Test PrettyPrintedRegionOp - exercising the following parser APIs -// parseGenericOperationAfterOpName -// parseCustomOperationName -//===----------------------------------------------------------------------===// - -ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - - SMLoc loc = parser.getCurrentLocation(); - Location currLocation = parser.getEncodedSourceLoc(loc); - - // Parse the operands. - SmallVector operands; - if (parser.parseOperandList(operands)) - return failure(); - - // Check if we are parsing the pretty-printed version - // test.pretty_printed_region start end : - // Else fallback to parsing the "non pretty-printed" version. 
- if (!succeeded(parser.parseOptionalKeyword("start"))) - return parser.parseGenericOperationAfterOpName(result, - llvm::ArrayRef(operands)); - - FailureOr parseOpNameInfo = parser.parseCustomOperationName(); - if (failed(parseOpNameInfo)) - return failure(); - - StringAttr innerOpName = parseOpNameInfo->getIdentifier(); - - FunctionType opFntype; - std::optional explicitLoc; - if (parser.parseKeyword("end") || parser.parseColon() || - parser.parseType(opFntype) || - parser.parseOptionalLocationSpecifier(explicitLoc)) - return failure(); - - // If location of the op is explicitly provided, then use it; Else use - // the parser's current location. - Location opLoc = explicitLoc.value_or(currLocation); - - // Derive the SSA-values for op's operands. - if (parser.resolveOperands(operands, opFntype.getInputs(), loc, - result.operands)) - return failure(); - - // Add a region for op. - Region ®ion = *result.addRegion(); - - // Create a basic-block inside op's region. - Block &block = region.emplaceBlock(); - - // Create and insert an "inner-op" operation in the block. - // Just for testing purposes, we can assume that inner op is a binary op with - // result and operand types all same as the test-op's first operand. - Type innerOpType = opFntype.getInput(0); - Value lhs = block.addArgument(innerOpType, opLoc); - Value rhs = block.addArgument(innerOpType, opLoc); - - OpBuilder builder(parser.getBuilder().getContext()); - builder.setInsertionPointToStart(&block); - - Operation *innerOp = - builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); - - // Insert a return statement in the block returning the inner-op's result. - builder.create(innerOp->getLoc(), innerOp->getResults()); - - // Populate the op operation-state with result-type and location. 
- result.addTypes(opFntype.getResults()); - result.location = innerOp->getLoc(); - - return success(); -} - -void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { - p << ' '; - p.printOperands(getOperands()); - - Operation &innerOp = getRegion().front().front(); - // Assuming that region has a single non-terminator inner-op, if the inner-op - // meets some criteria (which in this case is a simple one based on the name - // of inner-op), then we can print the entire region in a succinct way. - // Here we assume that the prototype of "test.special.op" can be trivially - // derived while parsing it back. - if (innerOp.getName().getStringRef().equals("test.special.op")) { - p << " start test.special.op end"; - } else { - p << " ("; - p.printRegion(getRegion()); - p << ")"; - } - - p << " : "; - p.printFunctionalType(*this); -} - -//===----------------------------------------------------------------------===// -// Test PolyForOp - parse list of region arguments. -//===----------------------------------------------------------------------===// - -ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector ivsInfo; - // Parse list of region arguments without a delimiter. - if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) - return failure(); - - // Parse the body region. 
- Region *body = result.addRegion(); - for (auto &iv : ivsInfo) - iv.type = parser.getBuilder().getIndexType(); - return parser.parseRegion(*body, ivsInfo); -} - -void PolyForOp::print(OpAsmPrinter &p) { - p << " "; - llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) { - p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true); - }); - p << " "; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -void PolyForOp::getAsmBlockArgumentNames(Region ®ion, - OpAsmSetValueNameFn setNameFn) { - auto arrayAttr = getOperation()->getAttrOfType("arg_names"); - if (!arrayAttr) - return; - auto args = getRegion().front().getArguments(); - auto e = std::min(arrayAttr.size(), args.size()); - for (unsigned i = 0; i < e; ++i) { - if (auto strAttr = dyn_cast(arrayAttr[i])) - setNameFn(args[i], strAttr.getValue()); - } -} - -//===----------------------------------------------------------------------===// -// TestAttrWithLoc - parse/printOptionalLocationSpecifier -//===----------------------------------------------------------------------===// - -static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) { - std::optional result; - SMLoc sourceLoc = p.getCurrentLocation(); - if (p.parseOptionalLocationSpecifier(result)) - return failure(); - if (result) - loc = *result; - else - loc = p.getEncodedSourceLoc(sourceLoc); - return success(); -} - -static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) { - p.printOptionalLocationSpecifier(cast(loc)); -} - //===----------------------------------------------------------------------===// // Test removing op with inner ops. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -29,7 +29,9 @@ let extraClassDeclaration = [{ void registerAttributes(); + void registerInterfaces(); void registerTypes(); + void registerOpsSyntax(); // Provides a custom printing/parsing for some operations. ::std::optional diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -0,0 +1,374 @@ +//===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Interfaces/FoldInterfaces.h" +#include "mlir/Reducer/ReductionPatternInterface.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace test; + +//===----------------------------------------------------------------------===// +// TestDialect version utilities +//===----------------------------------------------------------------------===// + +struct TestDialectVersion : public DialectVersion { + uint32_t major = 2; + uint32_t minor = 0; +}; + +//===----------------------------------------------------------------------===// +// TestDialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { + +/// Testing the correctness of some traits. 
+static_assert( + llvm::is_detected::value, + "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); +static_assert(OpTrait::hasSingleBlockImplicitTerminator< + SingleBlockImplicitTerminatorOp>::value, + "hasSingleBlockImplicitTerminator does not match " + "SingleBlockImplicitTerminatorOp"); + +struct TestResourceBlobManagerInterface + : public ResourceBlobManagerDialectInterfaceBase< + TestDialectResourceBlobHandle> { + using ResourceBlobManagerDialectInterfaceBase< + TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase; +}; + +namespace { +enum test_encoding { k_attr_params = 0 }; +} + +// Test support for interacting with the Bytecode reader/writer. +struct TestBytecodeDialectInterface : public BytecodeDialectInterface { + using BytecodeDialectInterface::BytecodeDialectInterface; + TestBytecodeDialectInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const final { + if (auto concreteAttr = llvm::dyn_cast(attr)) { + writer.writeVarInt(test_encoding::k_attr_params); + writer.writeVarInt(concreteAttr.getV0()); + writer.writeVarInt(concreteAttr.getV1()); + return success(); + } + return failure(); + } + + Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if (version.major < 2) + return readAttrOldEncoding(reader); + if (version.major == 2 && version.minor == 0) + return readAttrNewEncoding(reader); + // Forbid reading future versions by returning nullptr. + return Attribute(); + } + + // Emit a specific version of the dialect. 
+ void writeVersion(DialectBytecodeWriter &writer) const final { + auto version = TestDialectVersion(); + writer.writeVarInt(version.major); // major + writer.writeVarInt(version.minor); // minor + } + + std::unique_ptr + readVersion(DialectBytecodeReader &reader) const final { + uint64_t major, minor; + if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor))) + return nullptr; + auto version = std::make_unique(); + version->major = major; + version->minor = minor; + return version; + } + + LogicalResult upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if ((version.major == 2) && (version.minor == 0)) + return success(); + if (version.major > 2 || (version.major == 2 && version.minor > 0)) { + return topLevelOp->emitError() + << "current test dialect version is 2.0, can't parse version: " + << version.major << "." << version.minor; + } + // Prior version 2.0, the old op supported only a single attribute called + // "dimensions". We can perform the upgrade. + topLevelOp->walk([](TestVersionedOpA op) { + if (auto dims = op->getAttr("dimensions")) { + op->removeAttr("dimensions"); + op->setAttr("dims", dims); + } + op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); + }); + return success(); + } + +private: + Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The new encoding has v0 first, v1 second. 
+ uint64_t v0, v1; + if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } + + Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The old encoding has v1 first, v0 second. + uint64_t v0, v1; + if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } +}; + +// Test support for interacting with the AsmPrinter. +struct TestOpAsmInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr) + : OpAsmDialectInterface(dialect), blobManager(mgr) {} + + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + + AliasResult getAlias(Attribute attr, raw_ostream &os) const final { + StringAttr strAttr = dyn_cast(attr); + if (!strAttr) + return AliasResult::NoAlias; + + // Check the contents of the string attribute to see what the test alias + // should be named. 
+ std::optional aliasName = + StringSwitch>(strAttr.getValue()) + .Case("alias_test:dot_in_name", StringRef("test.alias")) + .Case("alias_test:trailing_digit", StringRef("test_alias0")) + .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) + .Case("alias_test:sanitize_conflict_a", + StringRef("test_alias_conflict0")) + .Case("alias_test:sanitize_conflict_b", + StringRef("test_alias_conflict0_")) + .Case("alias_test:tensor_encoding", StringRef("test_encoding")) + .Default(std::nullopt); + if (!aliasName) + return AliasResult::NoAlias; + + os << *aliasName; + return AliasResult::FinalAlias; + } + + AliasResult getAlias(Type type, raw_ostream &os) const final { + if (auto tupleType = dyn_cast(type)) { + if (tupleType.size() > 0 && + llvm::all_of(tupleType.getTypes(), [](Type elemType) { + return isa(elemType); + })) { + os << "test_tuple"; + return AliasResult::FinalAlias; + } + } + if (auto intType = dyn_cast(type)) { + if (intType.getSignedness() == + TestIntegerType::SignednessSemantics::Unsigned && + intType.getWidth() == 8) { + os << "test_ui8"; + return AliasResult::FinalAlias; + } + } + if (auto recType = dyn_cast(type)) { + if (recType.getName() == "type_to_alias") { + // We only make alias for a specific recursive type. + os << "testrec"; + return AliasResult::FinalAlias; + } + } + return AliasResult::NoAlias; + } + + //===------------------------------------------------------------------===// + // Resources + //===------------------------------------------------------------------===// + + std::string + getResourceKey(const AsmDialectResourceHandle &handle) const override { + return cast(handle).getKey().str(); + } + + FailureOr + declareResource(StringRef key) const final { + return blobManager.insert(key); + } + + LogicalResult parseResource(AsmParsedResourceEntry &entry) const final { + FailureOr blob = entry.parseAsBlob(); + if (failed(blob)) + return failure(); + + // Update the blob for this entry. 
+ blobManager.update(entry.getKey(), std::move(*blob)); + return success(); + } + + void + buildResources(Operation *op, + const SetVector &referencedResources, + AsmResourceBuilder &provider) const final { + blobManager.buildResources(provider, referencedResources.getArrayRef()); + } + +private: + /// The blob manager for the dialect. + TestResourceBlobManagerInterface &blobManager; +}; + +struct TestDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; + + /// Registered hook to check if the given region, which is attached to an + /// operation that is *not* isolated from above, should be used when + /// materializing constants. + bool shouldMaterializeInto(Region *region) const final { + // If this is a one region operation, then insert into it. + return isa(region->getParentOp()); + } +}; + +/// This class defines the interface for handling inlining with standard +/// operations. +struct TestInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + // Don't allow inlining calls that are marked `noinline`. + return !call->hasAttr("noinline"); + } + bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + // Inlining into test dialect regions is legal. + return true; + } + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + bool shouldAnalyzeRecursively(Operation *op) const final { + // Analyze recursively if this is not a functional region operation, it + // froms a separate functional scope. 
+ return !isa(op); + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only handle "test.return" here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } + + /// Attempt to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value input, + Type resultType, + Location conversionLoc) const final { + // Only allow conversion for i16/i32 types. 
+ if (!(resultType.isSignlessInteger(16) || + resultType.isSignlessInteger(32)) || + !(input.getType().isSignlessInteger(16) || + input.getType().isSignlessInteger(32))) + return nullptr; + return builder.create(conversionLoc, resultType, input); + } + + Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, + Value argument, + DictionaryAttr argumentAttrs) const final { + if (!argumentAttrs.contains("test.handle_argument")) + return argument; + return builder.create(call->getLoc(), argument.getType(), + argument); + } + + Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, + Value result, DictionaryAttr resultAttrs) const final { + if (!resultAttrs.contains("test.handle_result")) + return result; + return builder.create(call->getLoc(), result.getType(), + result); + } + + void processInlinedCallBlocks( + Operation *call, + iterator_range inlinedBlocks) const final { + if (!isa(call)) + return; + + // Set attributed on all ops in the inlined blocks. 
+ for (Block &block : inlinedBlocks) { + block.walk([&](Operation *op) { + op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); + }); + } + } +}; + +struct TestReductionPatternInterface : public DialectReductionPatternInterface { +public: + TestReductionPatternInterface(Dialect *dialect) + : DialectReductionPatternInterface(dialect) {} + + void populateReductionPatterns(RewritePatternSet &patterns) const final { + populateTestReductionPatterns(patterns); + } +}; + +} // namespace + +void TestDialect::registerInterfaces() { + auto &blobInterface = addInterface(); + addInterface(blobInterface); + + addInterfaces(); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -859,72 +859,7 @@ let results = (outs Variadic:$a, I32:$b, Optional:$c); } -// This is used to test that the fallback for a custom op's parser and printer -// is the dialect parser and printer hooks. 
-def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">; - -// Ops related to OIList primitive -def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { - let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword, - UnitAttr:$diffNameUnitAttrKeyword); - let assemblyFormat = [{ - oilist( `keyword` $keyword - | `otherKeyword` $otherKeyword - | `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict - }]; -} - -def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> { - let arguments = (ins Optional:$arg0, - Optional:$arg1, - Optional:$arg2); - let assemblyFormat = [{ - oilist( `keyword` $arg0 `:` type($arg0) - | `otherKeyword` $arg1 `:` type($arg1) - | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict - }]; -} - -def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> { - let arguments = (ins Variadic:$arg0, - Variadic:$arg1, - Variadic:$arg2); - let assemblyFormat = [{ - oilist( `keyword` `(` $arg0 `:` type($arg0) `)` - | `otherKeyword` `(` $arg1 `:` type($arg1) `)` - | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict - }]; -} - -def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> { - let arguments = (ins Variadic:$arg0, - Optional:$optOperand, - UnitAttr:$nowait); - let assemblyFormat = [{ - oilist( `private` `(` $arg0 `:` type($arg0) `)` - | `reduction` custom($optOperand) - | `nowait` $nowait - ) attr-dict - }]; -} - -def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> { - let assemblyFormat = [{ - oilist( `foo` | `bar` ) `buzz` attr-dict - }]; -} -def TestEllipsisOp : TEST_Op<"ellipsis"> { - let arguments = (ins Variadic:$operands, UnitAttr:$variadic); - let assemblyFormat = [{ - `(` $operands (`...` $variadic^)? `)` attr-dict `:` type($operands) `...` - }]; -} - -def ElseAnchorOp : TEST_Op<"else_anchor"> { - let arguments = (ins Optional:$a); - let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? 
`)` attr-dict"; -} // This is used to test encoding of a string attribute into an SSA name of a // pretty printed value name. @@ -963,11 +898,6 @@ let assemblyFormat = "regions attr-dict-with-keyword"; } -// This is used to test that the default dialect is not elided when printing an -// op with dots in the name to avoid parsing ambiguity. -def OpWithDotInNameOp : TEST_Op<"op.with_dot_in_name"> { - let assemblyFormat = "attr-dict"; -} // This is used to test the OpAsmOpInterface::getAsmBlockName() feature: // blocks nested in a region under this op will have a name defined by the @@ -1993,56 +1923,6 @@ let hasCustomAssemblyFormat = 1; } -def WrappingRegionOp : TEST_Op<"wrapping_region", - [SingleBlockImplicitTerminator<"TestReturnOp">]> { - let summary = "wrapping region operation"; - let description = [{ - Test op wrapping another op in a region, to test calling - parseGenericOperation from the custom parser. - }]; - - let results = (outs Variadic); - let regions = (region SizedRegion<1>:$region); - let hasCustomAssemblyFormat = 1; -} - -def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region", - [SingleBlockImplicitTerminator<"TestReturnOp">]> { - let summary = "pretty_printed_region operation"; - let description = [{ - Test-op can be printed either in a "pretty" or "non-pretty" way based on - some criteria. The custom parser parsers both the versions while testing - APIs: parseCustomOperationName & parseGenericOperationAfterOpName. - }]; - let arguments = (ins - AnyType:$input1, - AnyType:$input2 - ); - - let results = (outs AnyType); - let regions = (region SizedRegion<1>:$region); - let hasCustomAssemblyFormat = 1; -} - -def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> { - let summary = "polyfor operation"; - let description = [{ - Test op with multiple region arguments, each argument of index type. 
- }]; - let extraClassDeclaration = [{ - void getAsmBlockArgumentNames(mlir::Region ®ion, - mlir::OpAsmSetValueNameFn setNameFn); - }]; - let regions = (region SizedRegion<1>:$region); - let hasCustomAssemblyFormat = 1; -} - -def TestAttrWithLoc : TEST_Op<"attr_with_loc"> { - let summary = "op's attribute has a location"; - let arguments = (ins AnyAttr:$loc, AnyAttr:$value); - let assemblyFormat = "`(` $value `` custom($loc) `)` attr-dict"; -} - //===----------------------------------------------------------------------===// // Test OpAsmInterface. @@ -2055,598 +1935,6 @@ let results = (outs AnyType); } -//===----------------------------------------------------------------------===// -// Test Op Asm Format -//===----------------------------------------------------------------------===// - -def FormatLiteralOp : TEST_Op<"format_literal_op"> { - let assemblyFormat = [{ - `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` - `?` `+` `*` `{` `\n` `}` attr-dict - }]; -} - -// Test that we elide attributes that are within the syntax. -def FormatAttrOp : TEST_Op<"format_attr_op"> { - let arguments = (ins I64Attr:$attr); - let assemblyFormat = "$attr attr-dict"; -} - -// Test that we elide optional attributes that are within the syntax. -def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> { - let arguments = (ins OptionalAttr:$opt_attr); - let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict"; -} -def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> { - let arguments = (ins OptionalAttr:$opt_attr); - let assemblyFormat = "($opt_attr^)? attr-dict"; -} - -// Test that we format symbol name attributes properly. -def FormatSymbolNameAttrOp : TEST_Op<"format_symbol_name_attr_op"> { - let arguments = (ins SymbolNameAttr:$attr); - let assemblyFormat = "$attr attr-dict"; -} - -// Test that we format optional symbol name attributes properly. 
-def FormatOptSymbolNameAttrOp : TEST_Op<"format_opt_symbol_name_attr_op"> { - let arguments = (ins OptionalAttr:$opt_attr); - let assemblyFormat = "($opt_attr^)? attr-dict"; -} - -// Test that we format optional symbol reference attributes properly. -def FormatOptSymbolRefAttrOp : TEST_Op<"format_opt_symbol_ref_attr_op"> { - let arguments = (ins OptionalAttr:$opt_attr); - let assemblyFormat = "($opt_attr^)? attr-dict"; -} - -// Test that we elide attributes that are within the syntax. -def FormatAttrDictWithKeywordOp : TEST_Op<"format_attr_dict_w_keyword"> { - let arguments = (ins I64Attr:$attr, OptionalAttr:$opt_attr); - let assemblyFormat = "attr-dict-with-keyword"; -} - -// Test that we don't need to provide types in the format if they are buildable. -def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> { - let arguments = (ins I64:$buildable); - let results = (outs I64:$buildable_res); - let assemblyFormat = "$buildable attr-dict"; -} - -// Test various mixings of region formatting. -class FormatRegionBase - : TEST_Op<"format_region_" # suffix # "_op"> { - let regions = (region AnyRegion:$region); - let assemblyFormat = fmt; -} -def FormatRegionAOp : FormatRegionBase<"a", [{ - regions attr-dict -}]>; -def FormatRegionBOp : FormatRegionBase<"b", [{ - $region attr-dict -}]>; -def FormatRegionCOp : FormatRegionBase<"c", [{ - (`region` $region^)? attr-dict -}]>; -class FormatVariadicRegionBase - : TEST_Op<"format_variadic_region_" # suffix # "_op"> { - let regions = (region VariadicRegion:$regions); - let assemblyFormat = fmt; -} -def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{ - $regions attr-dict -}]>; -def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{ - ($regions^ `found_regions`)? 
attr-dict -}]>; -class FormatRegionImplicitTerminatorBase - : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op", - [SingleBlockImplicitTerminator<"TestReturnOp">]> { - let regions = (region AnyRegion:$region); - let assemblyFormat = fmt; -} -def FormatFormatRegionImplicitTerminatorAOp - : FormatRegionImplicitTerminatorBase<"a", [{ - $region attr-dict -}]>; - -// Test various mixings of result type formatting. -class FormatResultBase - : TEST_Op<"format_result_" # suffix # "_op"> { - let results = (outs I64:$buildable_res, AnyMemRef:$result); - let assemblyFormat = fmt; -} -def FormatResultAOp : FormatResultBase<"a", [{ - type($result) attr-dict -}]>; -def FormatResultBOp : FormatResultBase<"b", [{ - type(results) attr-dict -}]>; -def FormatResultCOp : FormatResultBase<"c", [{ - functional-type($buildable_res, $result) attr-dict -}]>; - -def FormatVariadicResult : TEST_Op<"format_variadic_result"> { - let results = (outs Variadic:$result); - let assemblyFormat = [{ `:` type($result) attr-dict}]; -} - -def FormatMultipleVariadicResults : TEST_Op<"format_multiple_variadic_results", - [AttrSizedResultSegments]> { - let results = (outs Variadic:$result0, Variadic:$result1); - let assemblyFormat = [{ - `:` `(` type($result0) `)` `,` `(` type($result1) `)` attr-dict - }]; -} - -// Test various mixings of operand type formatting. 
-class FormatOperandBase - : TEST_Op<"format_operand_" # suffix # "_op"> { - let arguments = (ins I64:$buildable, AnyMemRef:$operand); - let assemblyFormat = fmt; -} - -def FormatOperandAOp : FormatOperandBase<"a", [{ - operands `:` type(operands) attr-dict -}]>; -def FormatOperandBOp : FormatOperandBase<"b", [{ - operands `:` type($operand) attr-dict -}]>; -def FormatOperandCOp : FormatOperandBase<"c", [{ - $buildable `,` $operand `:` type(operands) attr-dict -}]>; -def FormatOperandDOp : FormatOperandBase<"d", [{ - $buildable `,` $operand `:` type($operand) attr-dict -}]>; -def FormatOperandEOp : FormatOperandBase<"e", [{ - $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict -}]>; - -def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> { - let successors = (successor VariadicSuccessor:$targets); - let assemblyFormat = "$targets attr-dict"; -} - -def FormatVariadicOperand : TEST_Op<"format_variadic_operand"> { - let arguments = (ins Variadic:$operand); - let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; -} -def FormatVariadicOfVariadicOperand - : TEST_Op<"format_variadic_of_variadic_operand"> { - let arguments = (ins - VariadicOfVariadic:$operand, - DenseI32ArrayAttr:$operand_segments - ); - let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; -} - -def FormatMultipleVariadicOperands : - TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> { - let arguments = (ins Variadic:$operand0, Variadic:$operand1); - let assemblyFormat = [{ - ` ` `(` $operand0 `)` `,` `(` $operand1 `:` type($operand1) `)` attr-dict - }]; -} - -// Test various mixings of optional operand and result type formatting. 
-class FormatOptionalOperandResultOpBase - : TEST_Op<"format_optional_operand_result_" # suffix # "_op", - [AttrSizedOperandSegments]> { - let arguments = (ins Optional:$optional, Variadic:$variadic); - let results = (outs Optional:$optional_res); - let assemblyFormat = fmt; -} - -def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{ - `(` $optional `:` type($optional) `)` `:` type($optional_res) - (`[` $variadic^ `]`)? attr-dict -}]>; - -def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{ - (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res) - (`[` $variadic^ `]`)? attr-dict -}]>; - -// Test optional result type formatting. -class FormatOptionalResultOpBase - : TEST_Op<"format_optional_result_" # suffix # "_op", - [AttrSizedResultSegments]> { - let results = (outs Optional:$optional, Variadic:$variadic); - let assemblyFormat = fmt; -} -def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{ - (`:` type($optional)^ `->` type($variadic))? attr-dict -}]>; - -def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{ - (`:` type($optional) `->` type($variadic)^)? attr-dict -}]>; - -def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{ - (`:` functional-type($optional, $variadic)^)? attr-dict -}]>; - -def FormatOptionalResultDOp - : TEST_Op<"format_optional_result_d_op" > { - let results = (outs Optional:$optional); - let assemblyFormat = "(`:` type($optional)^)? 
attr-dict"; -} - -def FormatTwoVariadicOperandsNoBuildableTypeOp - : TEST_Op<"format_two_variadic_operands_no_buildable_type_op", - [AttrSizedOperandSegments]> { - let arguments = (ins Variadic:$a, - Variadic:$b); - let assemblyFormat = [{ - `(` $a `:` type($a) `)` `->` `(` $b `:` type($b) `)` attr-dict - }]; -} - -def FormatInferVariadicTypeFromNonVariadic - : TEST_Op<"format_infer_variadic_type_from_non_variadic", - [SameOperandsAndResultType]> { - let arguments = (ins Variadic:$args); - let results = (outs AnyType:$result); - let assemblyFormat = "operands attr-dict `:` type($result)"; -} - -def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> { - let arguments = (ins UnitAttr:$is_optional); - let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict"; -} - -def FormatOptionalUnitAttrNoElide - : TEST_Op<"format_optional_unit_attribute_no_elide"> { - let arguments = (ins UnitAttr:$is_optional); - let assemblyFormat = "($is_optional^)? attr-dict"; -} - -def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> { - let arguments = (ins OptionalAttr:$attr); - let assemblyFormat = "($attr^)? attr-dict"; -} - -def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> { - let arguments = (ins DefaultValuedStrAttr:$str, - DefaultValuedStrAttr:$sym, - DefaultValuedAttr:$e); - let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict"; -} - -def FormatOptionalWithElse : TEST_Op<"format_optional_else"> { - let arguments = (ins UnitAttr:$isFirstBranchPresent); - let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? 
attr-dict"; -} - -def FormatCompoundAttr : TEST_Op<"format_compound_attr"> { - let arguments = (ins CompoundAttrA:$compound); - let assemblyFormat = "$compound attr-dict-with-keyword"; -} - -def FormatNestedAttr : TEST_Op<"format_nested_attr"> { - let arguments = (ins CompoundAttrNested:$nested); - let assemblyFormat = "$nested attr-dict-with-keyword"; -} - -def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> { - let arguments = (ins CompoundNestedOuter:$nested); - let assemblyFormat = "`nested` $nested attr-dict-with-keyword"; -} - -def FormatMaybeEmptyType : TEST_Op<"format_maybe_empty_type"> { - let arguments = (ins TestTypeOptionalValueType:$in); - let assemblyFormat = "$in `:` type($in) attr-dict"; -} - -def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> { - let arguments = (ins CompoundNestedOuter:$nested); - let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword"; -} - -def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> { - let arguments = (ins CompoundNestedOuterType:$nested); - let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword"; -} - -def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> { - let arguments = (ins CompoundNestedOuterType:$nested); - let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword"; -} - -//===----------------------------------------------------------------------===// -// Custom Directives - -def FormatCustomDirectiveOperands - : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> { - let arguments = (ins I64:$operand, Optional:$optOperand, - Variadic:$varOperands); - let assemblyFormat = [{ - custom( - $operand, $optOperand, $varOperands - ) - attr-dict - }]; -} - -def FormatCustomDirectiveOperandsAndTypes - : TEST_Op<"format_custom_directive_operands_and_types", - [AttrSizedOperandSegments]> { - let arguments = (ins AnyType:$operand, Optional:$optOperand, - 
Variadic:$varOperands); - let assemblyFormat = [{ - custom( - $operand, $optOperand, $varOperands, - type($operand), type($optOperand), type($varOperands) - ) - attr-dict - }]; -} - -def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> { - let regions = (region AnyRegion:$region, VariadicRegion:$other_regions); - let assemblyFormat = [{ - custom( - $region, $other_regions - ) - attr-dict - }]; -} - -def FormatCustomDirectiveResults - : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> { - let results = (outs AnyType:$result, Optional:$optResult, - Variadic:$varResults); - let assemblyFormat = [{ - custom( - type($result), type($optResult), type($varResults) - ) - attr-dict - }]; -} - -def FormatCustomDirectiveResultsWithTypeRefs - : TEST_Op<"format_custom_directive_results_with_type_refs", - [AttrSizedResultSegments]> { - let results = (outs AnyType:$result, Optional:$optResult, - Variadic:$varResults); - let assemblyFormat = [{ - custom( - type($result), type($optResult), type($varResults) - ) - custom( - ref(type($result)), ref(type($optResult)), ref(type($varResults)) - ) - attr-dict - }]; -} - -def FormatCustomDirectiveWithOptionalOperandRef - : TEST_Op<"format_custom_directive_with_optional_operand_ref"> { - let arguments = (ins Optional:$optOperand); - let assemblyFormat = [{ - ($optOperand^)? 
`:` - custom(ref($optOperand)) - attr-dict - }]; -} - -def FormatCustomDirectiveSuccessors - : TEST_Op<"format_custom_directive_successors", [Terminator]> { - let successors = (successor AnySuccessor:$successor, - VariadicSuccessor:$successors); - let assemblyFormat = [{ - custom( - $successor, $successors - ) - attr-dict - }]; -} - -def FormatCustomDirectiveAttributes - : TEST_Op<"format_custom_directive_attributes"> { - let arguments = (ins I64Attr:$attr, OptionalAttr:$optAttr); - let assemblyFormat = [{ - custom( - $attr, $optAttr - ) - attr-dict - }]; -} - -def FormatCustomDirectiveSpacing - : TEST_Op<"format_custom_directive_spacing"> { - let arguments = (ins StrAttr:$attr1, StrAttr:$attr2); - let assemblyFormat = [{ - custom($attr1) - custom($attr2) - attr-dict - }]; -} - -def FormatCustomDirectiveAttrDict - : TEST_Op<"format_custom_directive_attrdict"> { - let arguments = (ins I64Attr:$attr, OptionalAttr:$optAttr); - let assemblyFormat = [{ - custom( attr-dict ) - }]; -} - -def FormatLiteralFollowingOptionalGroup - : TEST_Op<"format_literal_following_optional_group"> { - let arguments = (ins TypeAttr:$type, OptionalAttr:$value); - let assemblyFormat = "(`(` $value^ `)`)? 
`:` $type attr-dict"; -} - -//===----------------------------------------------------------------------===// -// AllTypesMatch type inference - -def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [ - AllTypesMatch<["value1", "value2", "result"]> - ]> { - let arguments = (ins AnyType:$value1, AnyType:$value2); - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)"; -} - -def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [ - AllTypesMatch<["value1", "value2", "result"]> - ]> { - let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2); - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict $value1 `,` $value2"; -} - -//===----------------------------------------------------------------------===// -// TypesMatchWith type inference - -def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [ - TypesMatchWith<"result type matches operand", "value", "result", "$_self"> - ]> { - let arguments = (ins AnyType:$value); - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict $value `:` type($value)"; -} - -def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [ - RangedTypesMatchWith<"result type matches operand", "value", "result", - "llvm::make_range($_self.begin(), $_self.end())"> - ]> { - let arguments = (ins Variadic:$value); - let results = (outs Variadic:$result); - let assemblyFormat = "attr-dict $value `:` type($value)"; -} - -def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ - TypesMatchWith<"result type matches constant", "value", "result", "$_self"> - ]> { - let arguments = (ins TypedAttrInterface:$value); - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict $value"; -} - -def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [ - TypesMatchWith<"tuple result type matches operand type", "value", "result", - "::mlir::TupleType::get($_ctxt, $_self)"> 
- ]> { - let arguments = (ins AnyType:$value); - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict $value `:` type($value)"; -} - -//===----------------------------------------------------------------------===// -// InferTypeOpInterface type inference in assembly format - -def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> { - let results = (outs AnyType); - let assemblyFormat = "attr-dict"; - - let extraClassDeclaration = [{ - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); - return ::mlir::success(); - } - }]; -} - -// Check that formatget supports DeclareOpInterfaceMethods. -def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods]> { - let results = (outs AnyType); - let assemblyFormat = "attr-dict"; -} - -// Base class for testing mixing allOperandTypes, allOperands, and -// inferResultTypes. -class FormatInferAllTypesBaseOp traits = []> - : TEST_Op { - let arguments = (ins Variadic:$args); - let results = (outs Variadic:$outs); - let extraClassDeclaration = [{ - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - ::mlir::TypeRange operandTypes = operands.getTypes(); - inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end()); - return ::mlir::success(); - } - }]; -} - -// Test inferReturnTypes is called when allOperandTypes and allOperands is true. 
-def FormatInferTypeAllOperandsAndTypesOp - : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> { - let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; -} - -// Test inferReturnTypes is called when allOperandTypes is true and there is one -// ODS operand. -def FormatInferTypeAllOperandsAndTypesOneOperandOp - : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> { - let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)"; -} - -// Test inferReturnTypes is called when allOperandTypes is true and there are -// more than one ODS operands. -def FormatInferTypeAllOperandsAndTypesTwoOperandsOp - : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands", - [SameVariadicOperandSize]> { - let arguments = (ins Variadic:$args0, Variadic:$args1); - let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)"; -} - -// Test inferReturnTypes is called when allOperands is true and operand types -// are separately specified. -def FormatInferTypeAllTypesOp - : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> { - let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)"; -} - -// Test inferReturnTypes coupled with regions. 
-def FormatInferTypeRegionsOp - : TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> { - let results = (outs Variadic:$outs); - let regions = (region AnyRegion:$region); - let assemblyFormat = "$region attr-dict"; - let extraClassDeclaration = [{ - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - if (regions.empty()) - return ::mlir::failure(); - auto types = regions.front()->getArgumentTypes(); - inferredReturnTypes.assign(types.begin(), types.end()); - return ::mlir::success(); - } - }]; -} - -// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes). -def FormatInferTypeVariadicOperandsOp - : TEST_Op<"format_infer_type_variadic_operands", - [InferTypeOpInterface, AttrSizedOperandSegments]> { - let arguments = (ins Variadic:$a, Variadic:$b); - let results = (outs Variadic:$outs); - let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict"; - let extraClassDeclaration = [{ - static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, - ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, - ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - FormatInferTypeVariadicOperandsOpAdaptor adaptor( - operands, attributes, *properties.as(), {}); - auto aTypes = adaptor.getA().getTypes(); - auto bTypes = adaptor.getB().getTypes(); - inferredReturnTypes.append(aTypes.begin(), aTypes.end()); - inferredReturnTypes.append(bTypes.begin(), bTypes.end()); - return ::mlir::success(); - } - }]; -} - //===----------------------------------------------------------------------===// // Test ArrayOfAttr 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.h b/mlir/test/lib/Dialect/Test/TestOpsSyntax.h
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.h
@@ -0,0 +1,26 @@
+//===- TestOpsSyntax.h - Operations for testing syntax ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_DIALECT_TEST_TESTOPSSYNTAX_H
+#define MLIR_TEST_DIALECT_TEST_TESTOPSSYNTAX_H
+
+#include "TestAttributes.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+namespace test {
+// Forward declaration only. Ops in TestOpsSyntax.td name "TestReturnOp" as
+// their SingleBlockImplicitTerminator, so the generated declarations
+// presumably need the name to be visible before the .h.inc include below.
+class TestReturnOp;
+} // namespace test
+
+// Op class declarations generated by mlir-tblgen (-gen-op-decls) from
+// TestOpsSyntax.td.
+#define GET_OP_CLASSES
+#include "TestOpsSyntax.h.inc"
+
+#endif // MLIR_TEST_DIALECT_TEST_TESTOPSSYNTAX_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -0,0 +1,494 @@
+//===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOpsSyntax.h"
+#include "TestDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/Support/Base64.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// Test Format* operations
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Parsing
+
+/// Parses an optional `(operand)` group. `optOperand` is left empty when the
+/// input does not start with `(`.
+static ParseResult parseCustomOptionalOperand(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  if (succeeded(parser.parseOptionalLParen())) {
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand) || parser.parseRParen())
+      return failure();
+  }
+  return success();
+}
+
+/// Parses `operand [, optOperand] -> (varOperands)`.
+static ParseResult parseCustomDirectiveOperands(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
+  if (parser.parseOperand(operand))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand))
+      return failure();
+  }
+  if (parser.parseArrow() || parser.parseLParen() ||
+      parser.parseOperandList(varOperands) || parser.parseRParen())
+    return failure();
+  return success();
+}
+
+/// Parses `: type [, optType] -> (types)`; the inverse of
+/// printCustomDirectiveResults.
+static ParseResult
+parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+                            Type &optOperandType,
+                            SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseType(operandType))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    if (parser.parseType(optOperandType))
+      return failure();
+  }
+  if (parser.parseArrow() || parser.parseLParen() ||
+      parser.parseTypeList(varOperandTypes) || parser.parseRParen())
+    return failure();
+  return success();
+}
+
+/// Parses the `type_refs_capture` keyword followed by a type list in the
+/// parseCustomDirectiveResults form, and fails unless those types equal the
+/// already-captured `operandType`, `optOperandType` and `varOperandTypes`.
+static ParseResult
+parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
+                                 Type optOperandType,
+                                 const SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseKeyword("type_refs_capture"))
+    return failure();
+
+  Type operandType2, optOperandType2;
+  SmallVector<Type, 1> varOperandTypes2;
+  if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
+                                  varOperandTypes2))
+    return failure();
+
+  if (operandType != operandType2 || optOperandType != optOperandType2 ||
+      varOperandTypes != varOperandTypes2)
+    return failure();
+
+  return success();
+}
+
+/// Parses the operand form of parseCustomDirectiveOperands immediately
+/// followed by the type form of parseCustomDirectiveResults.
+static ParseResult parseCustomDirectiveOperandsAndTypes(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
+    Type &operandType, Type &optOperandType,
+    SmallVectorImpl<Type> &varOperandTypes) {
+  if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
+      parseCustomDirectiveResults(parser, operandType, optOperandType,
+                                  varOperandTypes))
+    return failure();
+  return success();
+}
+
+/// Parses `region [, region]`. At most one extra region is parsed into
+/// `varRegions`.
+static ParseResult parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  std::unique_ptr<Region> varRegion = std::make_unique<Region>();
+  if (parser.parseRegion(*varRegion))
+    return failure();
+  varRegions.emplace_back(std::move(varRegion));
+  return success();
+}
+
+/// Parses `successor [, successor]`.
+static ParseResult
+parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+                               SmallVectorImpl<Block *> &varSuccessors) {
+  if (parser.parseSuccessor(successor))
+    return failure();
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  Block *varSuccessor = nullptr;
+  if (parser.parseSuccessor(varSuccessor))
+    return failure();
+  // NOTE(review): the parsed successor is appended twice, while
+  // printCustomDirectiveSuccessors only prints the first entry. This looks
+  // deliberate for the test; confirm before changing.
+  varSuccessors.append(2, varSuccessor);
+  return success();
+}
+
+/// Parses `attr [, optAttr]`.
+static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
+                                                  IntegerAttr &attr,
+                                                  IntegerAttr &optAttr) {
+  if (parser.parseAttribute(attr))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    if (parser.parseAttribute(optAttr))
+      return failure();
+  }
+  return success();
+}
+
+/// Parses a single string attribute.
+static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
+                                               mlir::StringAttr &attr) {
+  return parser.parseAttribute(attr);
+}
+
+/// Parses an optional attribute dictionary into `attrs`.
+static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
+                                                NamedAttrList &attrs) {
+  return parser.parseOptionalAttrDict(attrs);
+}
+
+/// Parses an integer that must agree with `optOperand`: `0` when the optional
+/// operand is absent, any other value when it is present.
+static ParseResult parseCustomDirectiveOptionalOperandRef(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  int64_t operandCount = 0;
+  if (parser.parseInteger(operandCount))
+    return failure();
+  bool expectAbsent = operandCount == 0;
+  return success(expectAbsent != optOperand.has_value());
+}
+
+//===----------------------------------------------------------------------===//
+// Printing
+
+/// Prints `(optOperand) ` when the optional operand is present.
+static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
+                                       Value optOperand) {
+  if (optOperand)
+    printer << "(" << optOperand << ") ";
+}
+
+/// Prints `operand [, optOperand] -> (varOperands)`.
+static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+                                         Value operand, Value optOperand,
+                                         OperandRange varOperands) {
+  printer << operand;
+  if (optOperand)
+    printer << ", " << optOperand;
+  printer << " -> (" << varOperands << ")";
+}
+
+/// Prints ` : type [, optType] -> (types)`.
+static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+                                        Type operandType, Type optOperandType,
+                                        TypeRange varOperandTypes) {
+  printer << " : " << operandType;
+  if (optOperandType)
+    printer << ", " << optOperandType;
+  printer << " -> (" << varOperandTypes << ")";
+}
+
+/// Prints the `type_refs_capture` keyword followed by the type list.
+static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+                                             Operation *op, Type operandType,
+                                             Type optOperandType,
+                                             TypeRange varOperandTypes) {
+  printer << " type_refs_capture ";
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+/// Prints the operands followed by their types.
+static void printCustomDirectiveOperandsAndTypes(
+    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+    OperandRange varOperands, Type operandType, Type optOperandType,
+    TypeRange varOperandTypes) {
+  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+/// Prints `region` and, when there are variadic regions, `, ` followed by
+/// each of them.
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+                                        Region &region,
+                                        MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  if (!varRegions.empty()) {
+    printer << ", ";
+    // Distinct name so the loop variable does not shadow `region`.
+    for (Region &varRegion : varRegions)
+      printer.printRegion(varRegion);
+  }
+}
+
+/// Prints `successor` and, when present, the first variadic successor.
+static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
+                                           Block *successor,
+                                           SuccessorRange varSuccessors) {
+  printer << successor;
+  if (!varSuccessors.empty())
+    printer << ", " << varSuccessors.front();
+}
+
+/// Prints `attribute [, optAttribute]`.
+static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
+                                           Attribute attribute,
+                                           Attribute optAttribute) {
+  printer << attribute;
+  if (optAttribute)
+    printer << ", " << optAttribute;
+}
+
+/// Prints the attribute as-is.
+static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
+                                        Attribute attribute) {
+  printer << attribute;
+}
+
+/// Prints the attributes as an optional attribute dictionary.
+static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+                                         DictionaryAttr attrs) {
+  printer.printOptionalAttrDict(attrs.getValue());
+}
+
+/// Prints `1` when the optional operand is present and `0` otherwise.
+static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
+                                                   Operation *op,
+                                                   Value optOperand) {
+  printer << (optOperand ? "1" : "0");
+}
+//===----------------------------------------------------------------------===//
+// Test parser.
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional `: N` suffix and gives the op N results of index type.
+ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  if (parser.parseOptionalColon())
+    return success();
+  uint64_t numResults = 0;
+  if (parser.parseInteger(numResults))
+    return failure();
+
+  IndexType type = parser.getBuilder().getIndexType();
+  // 64-bit counter: a 32-bit `unsigned` would wrap before reaching a count
+  // above UINT_MAX and the loop would never terminate.
+  for (uint64_t i = 0; i < numResults; ++i)
+    result.addTypes(type);
+  return success();
+}
+
+/// Prints ` : N` when the op has N > 0 results.
+void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
+  if (unsigned numResults = getNumResults())
+    p << " : " << numResults;
+}
+
+/// Parses a bare keyword into the `keyword` string attribute.
+ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  StringRef keyword;
+  if (parser.parseKeyword(&keyword))
+    return failure();
+  result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
+  return success();
+}
+
+void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
+
+/// Parses a base64 string and stores the decoded bytes in the `b64` string
+/// attribute.
+ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
+                                   OperationState &result) {
+  std::vector<char> bytes;
+  if (parser.parseBase64Bytes(&bytes))
+    return failure();
+  // The decoded payload may be empty: use data(), since &bytes.front() is
+  // undefined behavior on an empty vector.
+  result.addAttribute("b64", parser.getBuilder().getStringAttr(
+                                 StringRef(bytes.data(), bytes.size())));
+  return success();
+}
+
+/// Re-encodes the `b64` attribute as a quoted base64 string.
+void ParseB64BytesOp::print(OpAsmPrinter &p) {
+  p << " \"" << llvm::encodeBase64(getB64()) << "\"";
+}
+
+/// Always infers a single 16-bit integer result.
+::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
+    ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+    OpaqueProperties properties, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+  return ::mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
+
+/// Parses `wraps` followed by one op in generic form. The op is placed in a
+/// new single-block region and terminated by a TestReturnOp that forwards its
+/// results. The wrapping op takes those result types and the wrapped op's
+/// location.
+ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  if (parser.parseKeyword("wraps"))
+    return failure();
+
+  // Parse the wrapped op in a region
+  Region &body = *result.addRegion();
+  Block &block = body.emplaceBlock();
+  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
+  if (!wrappedOp)
+    return failure();
+
+  // Create a return terminator in the inner region, pass as operand to the
+  // terminator the returned values from the wrapped operation.
+  SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
+  OpBuilder builder(parser.getContext());
+  builder.setInsertionPointToEnd(&block);
+  builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
+
+  // Get the results type for the wrapping op from the terminator operands.
+  Operation &returnOp = block.back();
+  result.types.append(returnOp.operand_type_begin(),
+                      returnOp.operand_type_end());
+
+  // Use the location of the wrapped op for the "test.wrapping_region" op.
+  result.location = wrappedOp->getLoc();
+
+  return success();
+}
+
+/// Prints `wraps` followed by the first op of the region's entry block in
+/// generic form.
+void WrappingRegionOp::print(OpAsmPrinter &p) {
+  p << " wraps ";
+  p.printGenericOp(&getRegion().front().front());
+}
+
+//===----------------------------------------------------------------------===//
+// Test PrettyPrintedRegionOp -  exercising the following parser APIs
+//   parseGenericOperationAfterOpName
+//   parseCustomOperationName
+//===----------------------------------------------------------------------===//
+
+ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+
+  SMLoc loc = parser.getCurrentLocation();
+  Location currLocation = parser.getEncodedSourceLoc(loc);
+
+  // Parse the operands.
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  if (parser.parseOperandList(operands))
+    return failure();
+
+  // Check if we are parsing the pretty-printed version
+  //  test.pretty_printed_region start <inner-op> end : <functional-type>
+  // Else fallback to parsing the "non pretty-printed" version.
+  if (failed(parser.parseOptionalKeyword("start")))
+    return parser.parseGenericOperationAfterOpName(result,
+                                                   llvm::ArrayRef(operands));
+
+  FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
+  if (failed(parseOpNameInfo))
+    return failure();
+
+  StringAttr innerOpName = parseOpNameInfo->getIdentifier();
+
+  FunctionType opFntype;
+  std::optional<Location> explicitLoc;
+  if (parser.parseKeyword("end") || parser.parseColon() ||
+      parser.parseType(opFntype) ||
+      parser.parseOptionalLocationSpecifier(explicitLoc))
+    return failure();
+
+  // If location of the op is explicitly provided, then use it; Else use
+  // the parser's current location.
+  Location opLoc = explicitLoc.value_or(currLocation);
+
+  // Derive the SSA-values for op's operands.
+  if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
+                             result.operands))
+    return failure();
+
+  // The inner op below takes its type from the first input type. An empty
+  // operand list with `() -> ...` passes resolveOperands, so reject it here
+  // instead of indexing past the end of the input types.
+  if (opFntype.getNumInputs() == 0)
+    return parser.emitError(loc, "expected at least one input type");
+
+  // Add a region for op.
+  Region &region = *result.addRegion();
+
+  // Create a basic-block inside op's region.
+  Block &block = region.emplaceBlock();
+
+  // Create and insert an "inner-op" operation in the block.
+  // Just for testing purposes, we can assume that inner op is a binary op with
+  // result and operand types all same as the test-op's first operand.
+  Type innerOpType = opFntype.getInput(0);
+  Value lhs = block.addArgument(innerOpType, opLoc);
+  Value rhs = block.addArgument(innerOpType, opLoc);
+
+  OpBuilder builder(parser.getBuilder().getContext());
+  builder.setInsertionPointToStart(&block);
+
+  Operation *innerOp =
+      builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
+
+  // Insert a return statement in the block returning the inner-op's result.
+  builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
+
+  // Populate the op operation-state with result-type and location.
+  result.addTypes(opFntype.getResults());
+  result.location = innerOp->getLoc();
+
+  return success();
+}
+
+void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
+  p << ' ';
+  p.printOperands(getOperands());
+
+  Operation &innerOp = getRegion().front().front();
+  // Assuming that region has a single non-terminator inner-op, if the inner-op
+  // meets some criteria (which in this case is a simple one based on the name
+  // of inner-op), then we can print the entire region in a succinct way.
+  // Here we assume that the prototype of "test.special.op" can be trivially
+  // derived while parsing it back.
+  if (innerOp.getName().getStringRef() == "test.special.op") {
+    p << " start test.special.op end";
+  } else {
+    p << " (";
+    p.printRegion(getRegion());
+    p << ")";
+  }
+
+  p << " : ";
+  p.printFunctionalType(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// Test PolyForOp - parse list of region arguments.
+//===----------------------------------------------------------------------===//
+
+/// Parses a delimiter-less list of region arguments (all given index type)
+/// followed by the body region.
+ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::Argument, 4> ivsInfo;
+  // Parse list of region arguments without a delimiter.
+  if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  Type indexType = parser.getBuilder().getIndexType();
+  for (auto &iv : ivsInfo)
+    iv.type = indexType;
+  return parser.parseRegion(*body, ivsInfo);
+}
+
+/// Prints the region arguments without types, then the region without its
+/// entry block argument list.
+void PolyForOp::print(OpAsmPrinter &p) {
+  p << " ";
+  llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) {
+    p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true);
+  });
+  p << " ";
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+/// Names entry-block arguments from the optional `arg_names` array attribute;
+/// non-string entries and arguments beyond the array's length are skipped.
+void PolyForOp::getAsmBlockArgumentNames(Region &region,
+                                         OpAsmSetValueNameFn setNameFn) {
+  auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
+  if (!arrayAttr)
+    return;
+  // Guard against a region without blocks before calling front().
+  if (getRegion().empty())
+    return;
+  auto args = getRegion().front().getArguments();
+  size_t e = std::min(arrayAttr.size(), args.size());
+  for (size_t i = 0; i < e; ++i) {
+    if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
+      setNameFn(args[i], strAttr.getValue());
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// TestAttrWithLoc - parse/printOptionalLocationSpecifier
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional `loc(...)` specifier into `loc`; when absent, uses the
+/// parser's current source location instead.
+static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
+  std::optional<Location> result;
+  SMLoc sourceLoc = p.getCurrentLocation();
+  if (p.parseOptionalLocationSpecifier(result))
+    return failure();
+  if (result)
+    loc = *result;
+  else
+    loc = p.getEncodedSourceLoc(sourceLoc);
+  return success();
+}
+
+static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
+  p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
+}
+
+#define GET_OP_CLASSES
+#include "TestOpsSyntax.cpp.inc"
+
+/// Registers every op generated from TestOpsSyntax.td with the test dialect.
+void TestDialect::registerOpsSyntax() {
+  addOperations<
+#define GET_OP_LIST
+#include "TestOpsSyntax.cpp.inc"
+      >();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
@@ -0,0 +1,741 @@
+
+//===-- TestOpsSyntax.td - Operations for testing syntax ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_OPS_SYNTAX
+#define TEST_OPS_SYNTAX
+
+include "TestAttrDefs.td"
+include "TestDialect.td"
+include "TestTypeDefs.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/IR/OpBase.td"
+
+// NOTE(review): in this copy, template arguments written in angle brackets
+// appear to have been stripped (e.g. the TEST_Op parameter list and its `Op`
+// base arguments just below, and element constraints such as
+// `Variadic:$arg0` / `Optional:$a`). Restore them from the upstream file;
+// they are not guessed here because the constraints cannot be inferred.
+// Base class used by every op definition in this file.
+class TEST_Op traits = []> :
+    Op;
+
+def WrappingRegionOp : TEST_Op<"wrapping_region",
+    [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+  let summary = "wrapping region operation";
+  let description = [{
+    Test op wrapping another op in a region, to test calling
+    parseGenericOperation from the custom parser.
+  }];
+
+  let results = (outs Variadic);
+  let regions = (region SizedRegion<1>:$region);
+  let hasCustomAssemblyFormat = 1;
+}
+
+def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
+    [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+  let summary = "pretty_printed_region operation";
+  let description = [{
+    Test-op can be printed either in a "pretty" or "non-pretty" way based on
+    some criteria. The custom parser parses both the versions while testing
+    APIs: parseCustomOperationName & parseGenericOperationAfterOpName.
+  }];
+  let arguments = (ins
+    AnyType:$input1,
+    AnyType:$input2
+  );
+
+  let results = (outs AnyType);
+  let regions = (region SizedRegion<1>:$region);
+  let hasCustomAssemblyFormat = 1;
+}
+
+def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> {
+  let summary = "polyfor operation";
+  let description = [{
+    Test op with multiple region arguments, each argument of index type.
+  }];
+  let extraClassDeclaration = [{
+    void getAsmBlockArgumentNames(mlir::Region &region,
+                                  mlir::OpAsmSetValueNameFn setNameFn);
+  }];
+  let regions = (region SizedRegion<1>:$region);
+  let hasCustomAssemblyFormat = 1;
+}
+
+// `custom<OptionalLoc>` dispatches to parseOptionalLoc/printOptionalLoc.
+def TestAttrWithLoc : TEST_Op<"attr_with_loc"> {
+  let summary = "op's attribute has a location";
+  let arguments = (ins AnyAttr:$loc, AnyAttr:$value);
+  let assemblyFormat = "`(` $value `` custom<OptionalLoc>($loc) `)` attr-dict";
+}
+
+// -----
+
+// This is used to test that the fallback for a custom op's parser and printer
+// is the dialect parser and printer hooks.
+def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
+
+// Ops related to OIList primitive
+def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
+  let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword,
+                       UnitAttr:$diffNameUnitAttrKeyword);
+  let assemblyFormat = [{
+    oilist( `keyword` $keyword
+          | `otherKeyword` $otherKeyword
+          | `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict
+  }];
+}
+
+def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> {
+  let arguments = (ins Optional:$arg0,
+                       Optional:$arg1,
+                       Optional:$arg2);
+  let assemblyFormat = [{
+    oilist( `keyword` $arg0 `:` type($arg0)
+          | `otherKeyword` $arg1 `:` type($arg1)
+          | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict
+  }];
+}
+
+def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic:$arg0,
+                       Variadic:$arg1,
+                       Variadic:$arg2);
+  let assemblyFormat = [{
+    oilist( `keyword` `(` $arg0 `:` type($arg0) `)`
+          | `otherKeyword` `(` $arg1 `:` type($arg1) `)`
+          | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict
+  }];
+}
+
+// `custom<CustomOptionalOperand>` dispatches to
+// parseCustomOptionalOperand/printCustomOptionalOperand.
+def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic:$arg0,
+                       Optional:$optOperand,
+                       UnitAttr:$nowait);
+  let assemblyFormat = [{
+    oilist( `private` `(` $arg0 `:` type($arg0) `)`
+          | `reduction` custom<CustomOptionalOperand>($optOperand)
+          | `nowait` $nowait
+    ) attr-dict
+  }];
+}
+
+def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> {
+  let assemblyFormat = [{
+    oilist( `foo` | `bar` ) `buzz` attr-dict
+  }];
+}
+
+def TestEllipsisOp : TEST_Op<"ellipsis"> {
+  let arguments = (ins Variadic:$operands, UnitAttr:$variadic);
+  let assemblyFormat = [{
+    `(` $operands (`...` $variadic^)? `)` attr-dict `:` type($operands) `...`
+  }];
+}
+
+def ElseAnchorOp : TEST_Op<"else_anchor"> {
+  let arguments = (ins Optional:$a);
+  let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? `)` attr-dict";
+}
+
+// This is used to test that the default dialect is not elided when printing an
+// op with dots in the name to avoid parsing ambiguity.
+def OpWithDotInNameOp : TEST_Op<"op.with_dot_in_name"> {
+  let assemblyFormat = "attr-dict";
+}
+
+// --------------
+
+//===----------------------------------------------------------------------===//
+// Test Op Asm Format
+//===----------------------------------------------------------------------===//
+
+def FormatLiteralOp : TEST_Op<"format_literal_op"> {
+  let assemblyFormat = [{
+    `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)`
+    `?` `+` `*` `{` `\n` `}` attr-dict
+  }];
+}
+
+// Test that we elide attributes that are within the syntax.
+def FormatAttrOp : TEST_Op<"format_attr_op"> {
+  let arguments = (ins I64Attr:$attr);
+  let assemblyFormat = "$attr attr-dict";
+}
+
+// Test that we elide optional attributes that are within the syntax.
+def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
+  let arguments = (ins OptionalAttr:$opt_attr);
+  let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
+}
+def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
+  let arguments = (ins OptionalAttr:$opt_attr);
+  let assemblyFormat = "($opt_attr^)? attr-dict";
+}
+
+// Test that we format symbol name attributes properly.
+def FormatSymbolNameAttrOp : TEST_Op<"format_symbol_name_attr_op"> {
+  let arguments = (ins SymbolNameAttr:$attr);
+  let assemblyFormat = "$attr attr-dict";
+}
+
+// Test that we format optional symbol name attributes properly.
+def FormatOptSymbolNameAttrOp : TEST_Op<"format_opt_symbol_name_attr_op"> {
+  let arguments = (ins OptionalAttr:$opt_attr);
+  let assemblyFormat = "($opt_attr^)? attr-dict";
+}
+
+// Test that we format optional symbol reference attributes properly.
+def FormatOptSymbolRefAttrOp : TEST_Op<"format_opt_symbol_ref_attr_op"> {
+  let arguments = (ins OptionalAttr:$opt_attr);
+  let assemblyFormat = "($opt_attr^)? attr-dict";
+}
+
+// Test that we elide attributes that are within the syntax.
+def FormatAttrDictWithKeywordOp : TEST_Op<"format_attr_dict_w_keyword"> {
+  let arguments = (ins I64Attr:$attr, OptionalAttr:$opt_attr);
+  let assemblyFormat = "attr-dict-with-keyword";
+}
+
+// Test that we don't need to provide types in the format if they are buildable.
+def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
+  let arguments = (ins I64:$buildable);
+  let results = (outs I64:$buildable_res);
+  let assemblyFormat = "$buildable attr-dict";
+}
+
+// NOTE(review): in this copy the template parameter lists of the Format*Base
+// classes below (which use `suffix` and `fmt`) and the element constraints of
+// `OptionalAttr`, `VariadicRegion` and `Variadic` appear to have been
+// stripped; restore them from the upstream file.
+
+// Test various mixings of region formatting.
+// Each def instantiates op "format_region_<suffix>_op" with `fmt` as its
+// assembly format.
+class FormatRegionBase
+    : TEST_Op<"format_region_" # suffix # "_op"> {
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = fmt;
+}
+def FormatRegionAOp : FormatRegionBase<"a", [{
+  regions attr-dict
+}]>;
+def FormatRegionBOp : FormatRegionBase<"b", [{
+  $region attr-dict
+}]>;
+def FormatRegionCOp : FormatRegionBase<"c", [{
+  (`region` $region^)? attr-dict
+}]>;
+// Same pattern as FormatRegionBase, but with a variadic list of regions.
+class FormatVariadicRegionBase
+    : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
+  let regions = (region VariadicRegion:$regions);
+  let assemblyFormat = fmt;
+}
+def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
+  $regions attr-dict
+}]>;
+def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
+  ($regions^ `found_regions`)? attr-dict
+}]>;
+// Same pattern, with TestReturnOp as the implicit single-block terminator.
+class FormatRegionImplicitTerminatorBase
+    : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
+      [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = fmt;
+}
+def FormatFormatRegionImplicitTerminatorAOp
+    : FormatRegionImplicitTerminatorBase<"a", [{
+  $region attr-dict
+}]>;
+
+// Test various mixings of result type formatting.
+// Results: an i64 `buildable_res` plus a memref `result`.
+class FormatResultBase
+    : TEST_Op<"format_result_" # suffix # "_op"> {
+  let results = (outs I64:$buildable_res, AnyMemRef:$result);
+  let assemblyFormat = fmt;
+}
+def FormatResultAOp : FormatResultBase<"a", [{
+  type($result) attr-dict
+}]>;
+def FormatResultBOp : FormatResultBase<"b", [{
+  type(results) attr-dict
+}]>;
+def FormatResultCOp : FormatResultBase<"c", [{
+  functional-type($buildable_res, $result) attr-dict
+}]>;
+
+def FormatVariadicResult : TEST_Op<"format_variadic_result"> {
+  let results = (outs Variadic:$result);
+  let assemblyFormat = [{ `:` type($result) attr-dict}];
+}
+
+def FormatMultipleVariadicResults : TEST_Op<"format_multiple_variadic_results",
+                                            [AttrSizedResultSegments]> {
+  let results = (outs Variadic:$result0, Variadic:$result1);
+  let assemblyFormat = [{
+    `:` `(` type($result0) `)` `,` `(` type($result1) `)` attr-dict
+  }];
+}
+
+// Test various mixings of operand type formatting.
+class FormatOperandBase + : TEST_Op<"format_operand_" # suffix # "_op"> { + let arguments = (ins I64:$buildable, AnyMemRef:$operand); + let assemblyFormat = fmt; +} + +def FormatOperandAOp : FormatOperandBase<"a", [{ + operands `:` type(operands) attr-dict +}]>; +def FormatOperandBOp : FormatOperandBase<"b", [{ + operands `:` type($operand) attr-dict +}]>; +def FormatOperandCOp : FormatOperandBase<"c", [{ + $buildable `,` $operand `:` type(operands) attr-dict +}]>; +def FormatOperandDOp : FormatOperandBase<"d", [{ + $buildable `,` $operand `:` type($operand) attr-dict +}]>; +def FormatOperandEOp : FormatOperandBase<"e", [{ + $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict +}]>; + +def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> { + let successors = (successor VariadicSuccessor:$targets); + let assemblyFormat = "$targets attr-dict"; +} + +def FormatVariadicOperand : TEST_Op<"format_variadic_operand"> { + let arguments = (ins Variadic:$operand); + let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; +} +def FormatVariadicOfVariadicOperand + : TEST_Op<"format_variadic_of_variadic_operand"> { + let arguments = (ins + VariadicOfVariadic:$operand, + DenseI32ArrayAttr:$operand_segments + ); + let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; +} + +def FormatMultipleVariadicOperands : + TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$operand0, Variadic:$operand1); + let assemblyFormat = [{ + ` ` `(` $operand0 `)` `,` `(` $operand1 `:` type($operand1) `)` attr-dict + }]; +} + +// Test various mixings of optional operand and result type formatting, with and without an optional group around $optional.
+class FormatOptionalOperandResultOpBase + : TEST_Op<"format_optional_operand_result_" # suffix # "_op", + [AttrSizedOperandSegments]> { + let arguments = (ins Optional:$optional, Variadic:$variadic); + let results = (outs Optional:$optional_res); + let assemblyFormat = fmt; +} + +def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{ + `(` $optional `:` type($optional) `)` `:` type($optional_res) + (`[` $variadic^ `]`)? attr-dict +}]>; + +def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{ + (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res) + (`[` $variadic^ `]`)? attr-dict +}]>; + +// Test optional result type formatting. +class FormatOptionalResultOpBase + : TEST_Op<"format_optional_result_" # suffix # "_op", + [AttrSizedResultSegments]> { + let results = (outs Optional:$optional, Variadic:$variadic); + let assemblyFormat = fmt; +} +def FormatOptionalResultAOp : FormatOptionalResultOpBase<"a", [{ + (`:` type($optional)^ `->` type($variadic))? attr-dict +}]>; + +def FormatOptionalResultBOp : FormatOptionalResultOpBase<"b", [{ + (`:` type($optional) `->` type($variadic)^)? attr-dict +}]>; + +def FormatOptionalResultCOp : FormatOptionalResultOpBase<"c", [{ + (`:` functional-type($optional, $variadic)^)? attr-dict +}]>; + +def FormatOptionalResultDOp + : TEST_Op<"format_optional_result_d_op" > { + let results = (outs Optional:$optional); + let assemblyFormat = "(`:` type($optional)^)?
attr-dict"; +} + +def FormatTwoVariadicOperandsNoBuildableTypeOp + : TEST_Op<"format_two_variadic_operands_no_buildable_type_op", + [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$a, + Variadic:$b); + let assemblyFormat = [{ + `(` $a `:` type($a) `)` `->` `(` $b `:` type($b) `)` attr-dict + }]; +} + +def FormatInferVariadicTypeFromNonVariadic + : TEST_Op<"format_infer_variadic_type_from_non_variadic", + [SameOperandsAndResultType]> { + let arguments = (ins Variadic:$args); + let results = (outs AnyType:$result); + let assemblyFormat = "operands attr-dict `:` type($result)"; +} + +def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> { + let arguments = (ins UnitAttr:$is_optional); + let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict"; +} + +def FormatOptionalUnitAttrNoElide + : TEST_Op<"format_optional_unit_attribute_no_elide"> { + let arguments = (ins UnitAttr:$is_optional); + let assemblyFormat = "($is_optional^)? attr-dict"; +} + +def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> { + let arguments = (ins OptionalAttr:$attr); + let assemblyFormat = "($attr^)? attr-dict"; +} + +def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> { + let arguments = (ins DefaultValuedStrAttr:$str, + DefaultValuedStrAttr:$sym, + DefaultValuedAttr:$e); + let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict"; +} + +def FormatOptionalWithElse : TEST_Op<"format_optional_else"> { + let arguments = (ins UnitAttr:$isFirstBranchPresent); + let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)?
attr-dict"; +} + +def FormatCompoundAttr : TEST_Op<"format_compound_attr"> { + let arguments = (ins CompoundAttrA:$compound); + let assemblyFormat = "$compound attr-dict-with-keyword"; +} + +def FormatNestedAttr : TEST_Op<"format_nested_attr"> { + let arguments = (ins CompoundAttrNested:$nested); + let assemblyFormat = "$nested attr-dict-with-keyword"; +} + +def FormatNestedCompoundAttr : TEST_Op<"format_cpmd_nested_attr"> { + let arguments = (ins CompoundNestedOuter:$nested); + let assemblyFormat = "`nested` $nested attr-dict-with-keyword"; +} + +def FormatMaybeEmptyType : TEST_Op<"format_maybe_empty_type"> { + let arguments = (ins TestTypeOptionalValueType:$in); + let assemblyFormat = "$in `:` type($in) attr-dict"; +} + +def FormatQualifiedCompoundAttr : TEST_Op<"format_qual_cpmd_nested_attr"> { + let arguments = (ins CompoundNestedOuter:$nested); + let assemblyFormat = "`nested` qualified($nested) attr-dict-with-keyword"; +} + +def FormatNestedType : TEST_Op<"format_cpmd_nested_type"> { + let arguments = (ins CompoundNestedOuterType:$nested); + let assemblyFormat = "$nested `nested` type($nested) attr-dict-with-keyword"; +} + +def FormatQualifiedNestedType : TEST_Op<"format_qual_cpmd_nested_type"> { + let arguments = (ins CompoundNestedOuterType:$nested); + let assemblyFormat = "$nested `nested` qualified(type($nested)) attr-dict-with-keyword"; +} + +//===----------------------------------------------------------------------===// +// Custom Directives + +def FormatCustomDirectiveOperands + : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> { + let arguments = (ins I64:$operand, Optional:$optOperand, + Variadic:$varOperands); + let assemblyFormat = [{ + custom( + $operand, $optOperand, $varOperands + ) + attr-dict + }]; +} + +def FormatCustomDirectiveOperandsAndTypes + : TEST_Op<"format_custom_directive_operands_and_types", + [AttrSizedOperandSegments]> { + let arguments = (ins AnyType:$operand, Optional:$optOperand, +
Variadic:$varOperands); + let assemblyFormat = [{ + custom( + $operand, $optOperand, $varOperands, + type($operand), type($optOperand), type($varOperands) + ) + attr-dict + }]; +} + +def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> { + let regions = (region AnyRegion:$region, VariadicRegion:$other_regions); + let assemblyFormat = [{ + custom( + $region, $other_regions + ) + attr-dict + }]; +} + +def FormatCustomDirectiveResults + : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> { + let results = (outs AnyType:$result, Optional:$optResult, + Variadic:$varResults); + let assemblyFormat = [{ + custom( + type($result), type($optResult), type($varResults) + ) + attr-dict + }]; +} + +def FormatCustomDirectiveResultsWithTypeRefs + : TEST_Op<"format_custom_directive_results_with_type_refs", + [AttrSizedResultSegments]> { + let results = (outs AnyType:$result, Optional:$optResult, + Variadic:$varResults); + let assemblyFormat = [{ + custom( + type($result), type($optResult), type($varResults) + ) + custom( + ref(type($result)), ref(type($optResult)), ref(type($varResults)) + ) + attr-dict + }]; +} + +def FormatCustomDirectiveWithOptionalOperandRef + : TEST_Op<"format_custom_directive_with_optional_operand_ref"> { + let arguments = (ins Optional:$optOperand); + let assemblyFormat = [{ + ($optOperand^)?
`:` + custom(ref($optOperand)) + attr-dict + }]; +} + +def FormatCustomDirectiveSuccessors + : TEST_Op<"format_custom_directive_successors", [Terminator]> { + let successors = (successor AnySuccessor:$successor, + VariadicSuccessor:$successors); + let assemblyFormat = [{ + custom( + $successor, $successors + ) + attr-dict + }]; +} + +def FormatCustomDirectiveAttributes + : TEST_Op<"format_custom_directive_attributes"> { + let arguments = (ins I64Attr:$attr, OptionalAttr:$optAttr); + let assemblyFormat = [{ + custom( + $attr, $optAttr + ) + attr-dict + }]; +} + +def FormatCustomDirectiveSpacing + : TEST_Op<"format_custom_directive_spacing"> { + let arguments = (ins StrAttr:$attr1, StrAttr:$attr2); + let assemblyFormat = [{ + custom($attr1) + custom($attr2) + attr-dict + }]; +} + +def FormatCustomDirectiveAttrDict + : TEST_Op<"format_custom_directive_attrdict"> { + let arguments = (ins I64Attr:$attr, OptionalAttr:$optAttr); + let assemblyFormat = [{ + custom( attr-dict ) + }]; +} + +def FormatLiteralFollowingOptionalGroup + : TEST_Op<"format_literal_following_optional_group"> { + let arguments = (ins TypeAttr:$type, OptionalAttr:$value); + let assemblyFormat = "(`(` $value^ `)`)?
`:` $type attr-dict"; +} + +//===----------------------------------------------------------------------===// +// AllTypesMatch type inference + +def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [ + AllTypesMatch<["value1", "value2", "result"]> + ]> { + let arguments = (ins AnyType:$value1, AnyType:$value2); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)"; +} + +def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [ + AllTypesMatch<["value1", "value2", "result"]> + ]> { + let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value1 `,` $value2"; +} + +//===----------------------------------------------------------------------===// +// TypesMatchWith type inference + +def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [ + TypesMatchWith<"result type matches operand", "value", "result", "$_self"> + ]> { + let arguments = (ins AnyType:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + +def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [ + RangedTypesMatchWith<"result type matches operand", "value", "result", + "llvm::make_range($_self.begin(), $_self.end())"> + ]> { + let arguments = (ins Variadic:$value); + let results = (outs Variadic:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + +def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ + TypesMatchWith<"result type matches constant", "value", "result", "$_self"> + ]> { + let arguments = (ins TypedAttrInterface:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value"; +} + +def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [ + TypesMatchWith<"tuple result type matches operand type", "value", "result", + "::mlir::TupleType::get($_ctxt, $_self)">
+ ]> { + let arguments = (ins AnyType:$value); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + +//===----------------------------------------------------------------------===// +// InferTypeOpInterface type inference in assembly format + +def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> { + let results = (outs AnyType); + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, + ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, + ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); + return ::mlir::success(); + } + }]; +} + +// Check that the assembly format generator supports DeclareOpInterfaceMethods. +def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods]> { + let results = (outs AnyType); + let assemblyFormat = "attr-dict"; +} + +// Base class for testing the mixing of allOperandTypes, allOperands, and +// inferResultTypes; inferReturnTypes forwards the operand types as the result types. +class FormatInferAllTypesBaseOp traits = []> + : TEST_Op { + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$outs); + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, + ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, + ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + ::mlir::TypeRange operandTypes = operands.getTypes(); + inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end()); + return ::mlir::success(); + } + }]; +} + +// Test inferReturnTypes is called when allOperandTypes and allOperands are true.
+def FormatInferTypeAllOperandsAndTypesOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> { + let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; +} + +// Test inferReturnTypes is called when allOperandTypes is true and there is one +// ODS operand. +def FormatInferTypeAllOperandsAndTypesOneOperandOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> { + let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)"; +} + +// Test inferReturnTypes is called when allOperandTypes is true and there is +// more than one ODS operand. +def FormatInferTypeAllOperandsAndTypesTwoOperandsOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands", + [SameVariadicOperandSize]> { + let arguments = (ins Variadic:$args0, Variadic:$args1); + let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)"; +} + +// Test inferReturnTypes is called when allOperands is true and operand types +// are separately specified. +def FormatInferTypeAllTypesOp + : FormatInferAllTypesBaseOp<"format_infer_type_all_types"> { + let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)"; +} + +// Test inferReturnTypes coupled with regions: the result types are the argument types of the first region, and inference fails when there are no regions.
+def FormatInferTypeRegionsOp + : TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> { + let results = (outs Variadic:$outs); + let regions = (region AnyRegion:$region); + let assemblyFormat = "$region attr-dict"; + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, + ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, + ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + if (regions.empty()) + return ::mlir::failure(); + auto types = regions.front()->getArgumentTypes(); + inferredReturnTypes.assign(types.begin(), types.end()); + return ::mlir::success(); + } + }]; +} + +// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes); the inferred result types are the types of $a followed by those of $b. +def FormatInferTypeVariadicOperandsOp + : TEST_Op<"format_infer_type_variadic_operands", + [InferTypeOpInterface, AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$a, Variadic:$b); + let results = (outs Variadic:$outs); + let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict"; + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, + ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, + ::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + FormatInferTypeVariadicOperandsOpAdaptor adaptor( + operands, attributes, *properties.as(), {}); + auto aTypes = adaptor.getA().getTypes(); + auto bTypes = adaptor.getB().getTypes(); + inferredReturnTypes.append(aTypes.begin(), aTypes.end()); + inferredReturnTypes.append(bTypes.begin(), bTypes.end()); + return ::mlir::success(); + } + }]; +} + +#endif // TEST_OPS_SYNTAX diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp ---
a/mlir/unittests/IR/AdaptorTest.cpp +++ b/mlir/unittests/IR/AdaptorTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOpsSyntax.h" #include "gmock/gmock.h" #include "gtest/gtest.h"