diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp --- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp +++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "TestTypes.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp --- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp +++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/IR/BuiltinAttributes.h" diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" #include "mlir/IR/Builders.h" diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -26,7 +26,7 @@ set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) -mlir_tablegen(TestOps.cpp.inc -gen-op-defs) +mlir_tablegen(TestOps.cpp.inc -gen-op-defs -op-shard-count=4) 
mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test) mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test) mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) @@ -40,7 +40,13 @@ add_mlir_library(MLIRTestDialect TestAttributes.cpp TestDialect.cpp + TestFormatUtils.cpp TestInterfaces.cpp + TestOpDefs.cpp + TestOps0.cpp + TestOps1.cpp + TestOps2.cpp + TestOps3.cpp TestPatterns.cpp TestTraits.cpp TestTypes.cpp diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -24,6 +24,7 @@ #include "TestAttrInterfaces.h.inc" #include "TestOpEnums.h.inc" +#include "TestOpStructs.h.inc" #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -13,16 +13,8 @@ #include "TestAttributes.h" #include "TestDialect.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/ExtensibleDialect.h" -#include "mlir/IR/Types.h" -#include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/ADT/bit.h" -#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace test; @@ -199,7 +191,8 @@ //===----------------------------------------------------------------------===// #include "TestAttrInterfaces.cpp.inc" - +#include "TestOpEnums.cpp.inc" +#include "TestOpStructs.cpp.inc" #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -14,46 
+14,19 @@ #ifndef MLIR_TESTDIALECT_H #define MLIR_TESTDIALECT_H -#include "TestAttributes.h" -#include "TestInterfaces.h" -#include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/DLTI/Traits.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/ExtensibleDialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/RegionKindInterface.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/CopyOpInterface.h" -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" namespace mlir { -class DLTIDialect; class RewritePatternSet; -} // namespace mlir +} // end namespace mlir -#include "TestOpInterfaces.h.inc" -#include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" -#define GET_OP_CLASSES -#include "TestOps.h.inc" - namespace test { -void registerTestDialect(::mlir::DialectRegistry ®istry); -void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns); -} // namespace test +void registerTestDialect(mlir::DialectRegistry ®istry); +void populateTestReductionPatterns(mlir::RewritePatternSet &patterns); +} // end namespace test #endif // MLIR_TESTDIALECT_H diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -7,33 +7,16 @@ 
//===----------------------------------------------------------------------===// #include "TestDialect.h" -#include "TestAttributes.h" -#include "TestInterfaces.h" +#include "TestOps.h" #include "TestTypes.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/ExtensibleDialect.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" -#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSwitch.h" -// Include this before the using namespace lines below to -// test that we don't have namespace dependencies. +// Include this before the using namespace lines below to test that we don't +// have namespace dependencies. #include "TestOpsDialect.cpp.inc" using namespace mlir; @@ -49,15 +32,8 @@ namespace { -/// Testing the correctness of some traits. -static_assert( - llvm::is_detected::value, - "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); -static_assert(OpTrait::hasSingleBlockImplicitTerminator< - SingleBlockImplicitTerminatorOp>::value, - "hasSingleBlockImplicitTerminator does not match " - "SingleBlockImplicitTerminatorOp"); +//===----------------------------------------------------------------------===// +// OpAsmDialectInterface // Test support for interacting with the AsmPrinter. 
struct TestOpAsmInterface : public OpAsmDialectInterface { @@ -110,6 +86,9 @@ } }; +//===----------------------------------------------------------------------===// +// DialectFoldInterface + struct TestDialectFoldInterface : public DialectFoldInterface { using DialectFoldInterface::DialectFoldInterface; @@ -122,6 +101,9 @@ } }; +//===----------------------------------------------------------------------===// +// DialectInlinerInterface + /// This class defines the interface for handling inlining with standard /// operations. struct TestInlinerInterface : public DialectInlinerInterface { @@ -129,7 +111,6 @@ //===--------------------------------------------------------------------===// // Analysis Hooks - //===--------------------------------------------------------------------===// bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { @@ -154,7 +135,6 @@ //===--------------------------------------------------------------------===// // Transformation Hooks - //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. @@ -203,8 +183,10 @@ } }; +//===----------------------------------------------------------------------===// +// DialectReductionPatternInterface + struct TestReductionPatternInterface : public DialectReductionPatternInterface { -public: TestReductionPatternInterface(Dialect *dialect) : DialectReductionPatternInterface(dialect) {} @@ -213,6 +195,32 @@ } }; +//===----------------------------------------------------------------------===// +// TestEffectOpInterface + +// This is the implementation of a dialect fallback for `TestEffectOpInterface`. 
+struct TestOpEffectInterfaceFallback + : public TestEffectOpInterface::FallbackModel< + TestOpEffectInterfaceFallback> { + static bool classof(Operation *op) { + bool isSupportedOp = + op->getName().getStringRef() == "test.unregistered_side_effect_op"; + assert(isSupportedOp && "Unexpected dispatch"); + return isSupportedOp; + } + + void + getEffects(Operation *op, + SmallVectorImpl> + &effects) const { + auto effectsAttr = op->getAttrOfType("effect_parameter"); + if (!effectsAttr) + return; + + effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -271,36 +279,13 @@ // TestDialect //===----------------------------------------------------------------------===// -static void testSideEffectOpGetEffect( - Operation *op, - SmallVectorImpl> &effects); - -// This is the implementation of a dialect fallback for `TestEffectOpInterface`. -struct TestOpEffectInterfaceFallback - : public TestEffectOpInterface::FallbackModel< - TestOpEffectInterfaceFallback> { - static bool classof(Operation *op) { - bool isSupportedOp = - op->getName().getStringRef() == "test.unregistered_side_effect_op"; - assert(isSupportedOp && "Unexpected dispatch"); - return isSupportedOp; - } - - void - getEffects(Operation *op, - SmallVectorImpl> - &effects) const { - testSideEffectOpGetEffect(op, effects); - } -}; - void TestDialect::initialize() { registerAttributes(); registerTypes(); - addOperations< -#define GET_OP_LIST -#include "TestOps.cpp.inc" - >(); + registerOps0(); + registerOps1(); + registerOps2(); + registerOps3(); registerDynamicOp(getDynamicGenericOp(this)); registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); registerDynamicOp(getDynamicCustomParserPrinterOp(this)); @@ -318,31 +303,25 @@ fallbackEffectOpInterfaces); } +//===----------------------------------------------------------------------===// +// TestDialect Hooks + Operation 
*TestDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { return builder.create(loc, type, value); } -::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( - ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, - ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, - ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); - return ::mlir::success(); -} - -void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, +void *TestDialect::getRegisteredInterfaceForOp(TypeID interfaceID, OperationName opName) { if (opName.getIdentifier() == "test.unregistered_side_effect_op" && - typeID == TypeID::get()) + interfaceID == TypeID::get()) return fallbackEffectOpInterfaces; return nullptr; } LogicalResult TestDialect::verifyOperationAttribute(Operation *op, - NamedAttribute namedAttr) { - if (namedAttr.getName() == "test.invalid_attr") + NamedAttribute attribute) { + if (attribute.getName() == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } @@ -350,8 +329,8 @@ LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, unsigned regionIndex, unsigned argIndex, - NamedAttribute namedAttr) { - if (namedAttr.getName() == "test.invalid_attr") + NamedAttribute attribute) { + if (attribute.getName() == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } @@ -359,8 +338,8 @@ LogicalResult TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, unsigned resultIndex, - NamedAttribute namedAttr) { - if (namedAttr.getName() == "test.invalid_attr") + NamedAttribute attribute) { + if (attribute.getName() == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } @@ -401,37 +380,6 @@ return {}; } 
-//===----------------------------------------------------------------------===// -// TestBranchOp -//===----------------------------------------------------------------------===// - -SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { - assert(index == 0 && "invalid successor index"); - return SuccessorOperands(getTargetOperandsMutable()); -} - -//===----------------------------------------------------------------------===// -// TestProducingBranchOp -//===----------------------------------------------------------------------===// - -SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { - assert(index <= 1 && "invalid successor index"); - if (index == 1) - return SuccessorOperands(getFirstOperandsMutable()); - return SuccessorOperands(getSecondOperandsMutable()); -} - -//===----------------------------------------------------------------------===// -// TestProducingBranchOp -//===----------------------------------------------------------------------===// - -SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { - assert(index <= 1 && "invalid successor index"); - if (index == 0) - return SuccessorOperands(0, getSuccessOperandsMutable()); - return SuccessorOperands(1, getErrorOperandsMutable()); -} - //===----------------------------------------------------------------------===// // TestDialectCanonicalizerOp //===----------------------------------------------------------------------===// @@ -448,1025 +396,3 @@ RewritePatternSet &results) const { results.add(&dialectCanonicalizationPattern); } - -//===----------------------------------------------------------------------===// -// TestCallOp -//===----------------------------------------------------------------------===// - -LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // Check that the callee attribute was specified. 
- auto fnAttr = (*this)->getAttrOfType("callee"); - if (!fnAttr) - return emitOpError("requires a 'callee' symbol reference attribute"); - if (!symbolTable.lookupNearestSymbolFrom(*this, fnAttr)) - return emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; - return success(); -} - -//===----------------------------------------------------------------------===// -// TestFoldToCallOp -//===----------------------------------------------------------------------===// - -namespace { -struct FoldToCallOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FoldToCallOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, TypeRange(), - op.getCalleeAttr(), ValueRange()); - return success(); - } -}; -} // namespace - -void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -//===----------------------------------------------------------------------===// -// Test Format* operations -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// Parsing - -static ParseResult parseCustomOptionalOperand( - OpAsmParser &parser, Optional &optOperand) { - if (succeeded(parser.parseOptionalLParen())) { - optOperand.emplace(); - if (parser.parseOperand(*optOperand) || parser.parseRParen()) - return failure(); - } - return success(); -} - -static ParseResult parseCustomDirectiveOperands( - OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, - Optional &optOperand, - SmallVectorImpl &varOperands) { - if (parser.parseOperand(operand)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - optOperand.emplace(); - if (parser.parseOperand(*optOperand)) - return failure(); - } - if (parser.parseArrow() || parser.parseLParen() || - 
parser.parseOperandList(varOperands) || parser.parseRParen()) - return failure(); - return success(); -} -static ParseResult -parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, - Type &optOperandType, - SmallVectorImpl &varOperandTypes) { - if (parser.parseColon()) - return failure(); - - if (parser.parseType(operandType)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseType(optOperandType)) - return failure(); - } - if (parser.parseArrow() || parser.parseLParen() || - parser.parseTypeList(varOperandTypes) || parser.parseRParen()) - return failure(); - return success(); -} -static ParseResult -parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, - Type optOperandType, - const SmallVectorImpl &varOperandTypes) { - if (parser.parseKeyword("type_refs_capture")) - return failure(); - - Type operandType2, optOperandType2; - SmallVector varOperandTypes2; - if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, - varOperandTypes2)) - return failure(); - - if (operandType != operandType2 || optOperandType != optOperandType2 || - varOperandTypes != varOperandTypes2) - return failure(); - - return success(); -} -static ParseResult parseCustomDirectiveOperandsAndTypes( - OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, - Optional &optOperand, - SmallVectorImpl &varOperands, - Type &operandType, Type &optOperandType, - SmallVectorImpl &varOperandTypes) { - if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || - parseCustomDirectiveResults(parser, operandType, optOperandType, - varOperandTypes)) - return failure(); - return success(); -} -static ParseResult parseCustomDirectiveRegions( - OpAsmParser &parser, Region ®ion, - SmallVectorImpl> &varRegions) { - if (parser.parseRegion(region)) - return failure(); - if (failed(parser.parseOptionalComma())) - return success(); - std::unique_ptr varRegion = std::make_unique(); - if 
(parser.parseRegion(*varRegion)) - return failure(); - varRegions.emplace_back(std::move(varRegion)); - return success(); -} -static ParseResult -parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, - SmallVectorImpl &varSuccessors) { - if (parser.parseSuccessor(successor)) - return failure(); - if (failed(parser.parseOptionalComma())) - return success(); - Block *varSuccessor; - if (parser.parseSuccessor(varSuccessor)) - return failure(); - varSuccessors.append(2, varSuccessor); - return success(); -} -static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, - IntegerAttr &attr, - IntegerAttr &optAttr) { - if (parser.parseAttribute(attr)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseAttribute(optAttr)) - return failure(); - } - return success(); -} - -static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, - NamedAttrList &attrs) { - return parser.parseOptionalAttrDict(attrs); -} -static ParseResult parseCustomDirectiveOptionalOperandRef( - OpAsmParser &parser, Optional &optOperand) { - int64_t operandCount = 0; - if (parser.parseInteger(operandCount)) - return failure(); - bool expectedOptionalOperand = operandCount == 0; - return success(expectedOptionalOperand != optOperand.hasValue()); -} - -//===----------------------------------------------------------------------===// -// Printing - -static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, - Value optOperand) { - if (optOperand) - printer << "(" << optOperand << ") "; -} - -static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, - Value operand, Value optOperand, - OperandRange varOperands) { - printer << operand; - if (optOperand) - printer << ", " << optOperand; - printer << " -> (" << varOperands << ")"; -} -static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, - Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printer << " : " << 
operandType; - if (optOperandType) - printer << ", " << optOperandType; - printer << " -> (" << varOperandTypes << ")"; -} -static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, - Operation *op, Type operandType, - Type optOperandType, - TypeRange varOperandTypes) { - printer << " type_refs_capture "; - printCustomDirectiveResults(printer, op, operandType, optOperandType, - varOperandTypes); -} -static void printCustomDirectiveOperandsAndTypes( - OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, - OperandRange varOperands, Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); - printCustomDirectiveResults(printer, op, operandType, optOperandType, - varOperandTypes); -} -static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, - Region ®ion, - MutableArrayRef varRegions) { - printer.printRegion(region); - if (!varRegions.empty()) { - printer << ", "; - for (Region ®ion : varRegions) - printer.printRegion(region); - } -} -static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, - Block *successor, - SuccessorRange varSuccessors) { - printer << successor; - if (!varSuccessors.empty()) - printer << ", " << varSuccessors.front(); -} -static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, - Attribute attribute, - Attribute optAttribute) { - printer << attribute; - if (optAttribute) - printer << ", " << optAttribute; -} - -static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs) { - printer.printOptionalAttrDict(attrs.getValue()); -} - -static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, - Operation *op, - Value optOperand) { - printer << (optOperand ? "1" : "0"); -} - -//===----------------------------------------------------------------------===// -// Test IsolatedRegionOp - parse passthrough region arguments. 
-//===----------------------------------------------------------------------===// - -ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse the input operand. - OpAsmParser::Argument argInfo; - argInfo.type = parser.getBuilder().getIndexType(); - if (parser.parseOperand(argInfo.ssaName) || - parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) - return failure(); - - // Parse the body region, and reuse the operand info as the argument info. - Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); -} - -void IsolatedRegionOp::print(OpAsmPrinter &p) { - p << "test.isolated_region "; - p.printOperand(getOperand()); - p.shadowRegionArgs(getRegion(), getOperand()); - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -//===----------------------------------------------------------------------===// -// Test SSACFGRegionOp -//===----------------------------------------------------------------------===// - -RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { - return RegionKind::SSACFG; -} - -//===----------------------------------------------------------------------===// -// Test GraphRegionOp -//===----------------------------------------------------------------------===// - -RegionKind GraphRegionOp::getRegionKind(unsigned index) { - return RegionKind::Graph; -} - -//===----------------------------------------------------------------------===// -// Test AffineScopeOp -//===----------------------------------------------------------------------===// - -ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse the body region, and reuse the operand info as the argument info. 
- Region *body = result.addRegion(); - return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); -} - -void AffineScopeOp::print(OpAsmPrinter &p) { - p << "test.affine_scope "; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -//===----------------------------------------------------------------------===// -// Test parser. -//===----------------------------------------------------------------------===// - -ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parser.parseOptionalColon()) - return success(); - uint64_t numResults; - if (parser.parseInteger(numResults)) - return failure(); - - IndexType type = parser.getBuilder().getIndexType(); - for (unsigned i = 0; i < numResults; ++i) - result.addTypes(type); - return success(); -} - -void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { - if (unsigned numResults = getNumResults()) - p << " : " << numResults; -} - -ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, - OperationState &result) { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return failure(); - result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); - return success(); -} - -void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } - -//===----------------------------------------------------------------------===// -// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. - -ParseResult WrappingRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parser.parseKeyword("wraps")) - return failure(); - - // Parse the wrapped op in a region - Region &body = *result.addRegion(); - body.push_back(new Block); - Block &block = body.back(); - Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); - if (!wrappedOp) - return failure(); - - // Create a return terminator in the inner region, pass as operand to the - // terminator the returned values from the wrapped operation. 
- SmallVector returnOperands(wrappedOp->getResults()); - OpBuilder builder(parser.getContext()); - builder.setInsertionPointToEnd(&block); - builder.create(wrappedOp->getLoc(), returnOperands); - - // Get the results type for the wrapping op from the terminator operands. - Operation &returnOp = body.back().back(); - result.types.append(returnOp.operand_type_begin(), - returnOp.operand_type_end()); - - // Use the location of the wrapped op for the "test.wrapping_region" op. - result.location = wrappedOp->getLoc(); - - return success(); -} - -void WrappingRegionOp::print(OpAsmPrinter &p) { - p << " wraps "; - p.printGenericOp(&getRegion().front().front()); -} - -//===----------------------------------------------------------------------===// -// Test PrettyPrintedRegionOp - exercising the following parser APIs -// parseGenericOperationAfterOpName -// parseCustomOperationName -//===----------------------------------------------------------------------===// - -ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - - SMLoc loc = parser.getCurrentLocation(); - Location currLocation = parser.getEncodedSourceLoc(loc); - - // Parse the operands. - SmallVector operands; - if (parser.parseOperandList(operands)) - return failure(); - - // Check if we are parsing the pretty-printed version - // test.pretty_printed_region start end : - // Else fallback to parsing the "non pretty-printed" version. 
- if (!succeeded(parser.parseOptionalKeyword("start"))) - return parser.parseGenericOperationAfterOpName( - result, llvm::makeArrayRef(operands)); - - FailureOr parseOpNameInfo = parser.parseCustomOperationName(); - if (failed(parseOpNameInfo)) - return failure(); - - StringAttr innerOpName = parseOpNameInfo->getIdentifier(); - - FunctionType opFntype; - Optional explicitLoc; - if (parser.parseKeyword("end") || parser.parseColon() || - parser.parseType(opFntype) || - parser.parseOptionalLocationSpecifier(explicitLoc)) - return failure(); - - // If location of the op is explicitly provided, then use it; Else use - // the parser's current location. - Location opLoc = explicitLoc.getValueOr(currLocation); - - // Derive the SSA-values for op's operands. - if (parser.resolveOperands(operands, opFntype.getInputs(), loc, - result.operands)) - return failure(); - - // Add a region for op. - Region ®ion = *result.addRegion(); - - // Create a basic-block inside op's region. - Block &block = region.emplaceBlock(); - - // Create and insert an "inner-op" operation in the block. - // Just for testing purposes, we can assume that inner op is a binary op with - // result and operand types all same as the test-op's first operand. - Type innerOpType = opFntype.getInput(0); - Value lhs = block.addArgument(innerOpType, opLoc); - Value rhs = block.addArgument(innerOpType, opLoc); - - OpBuilder builder(parser.getBuilder().getContext()); - builder.setInsertionPointToStart(&block); - - Operation *innerOp = - builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); - - // Insert a return statement in the block returning the inner-op's result. - builder.create(innerOp->getLoc(), innerOp->getResults()); - - // Populate the op operation-state with result-type and location. 
- result.addTypes(opFntype.getResults()); - result.location = innerOp->getLoc(); - - return success(); -} - -void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { - p << ' '; - p.printOperands(getOperands()); - - Operation &innerOp = getRegion().front().front(); - // Assuming that region has a single non-terminator inner-op, if the inner-op - // meets some criteria (which in this case is a simple one based on the name - // of inner-op), then we can print the entire region in a succinct way. - // Here we assume that the prototype of "special.op" can be trivially derived - // while parsing it back. - if (innerOp.getName().getStringRef().equals("special.op")) { - p << " start special.op end"; - } else { - p << " ("; - p.printRegion(getRegion()); - p << ")"; - } - - p << " : "; - p.printFunctionalType(*this); -} - -//===----------------------------------------------------------------------===// -// Test PolyForOp - parse list of region arguments. -//===----------------------------------------------------------------------===// - -ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector ivsInfo; - // Parse list of region arguments without a delimiter. - if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) - return failure(); - - // Parse the body region. 
- Region *body = result.addRegion(); - for (auto &iv : ivsInfo) - iv.type = parser.getBuilder().getIndexType(); - return parser.parseRegion(*body, ivsInfo); -} - -void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); } - -void PolyForOp::getAsmBlockArgumentNames(Region ®ion, - OpAsmSetValueNameFn setNameFn) { - auto arrayAttr = getOperation()->getAttrOfType("arg_names"); - if (!arrayAttr) - return; - auto args = getRegion().front().getArguments(); - auto e = std::min(arrayAttr.size(), args.size()); - for (unsigned i = 0; i < e; ++i) { - if (auto strAttr = arrayAttr[i].dyn_cast()) - setNameFn(args[i], strAttr.getValue()); - } -} - -//===----------------------------------------------------------------------===// -// Test removing op with inner ops. -//===----------------------------------------------------------------------===// - -namespace { -struct TestRemoveOpWithInnerOps - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } - - LogicalResult matchAndRewrite(TestOpWithRegionPattern op, - PatternRewriter &rewriter) const override { - rewriter.eraseOp(op); - return success(); - } -}; -} // namespace - -void TestOpWithRegionPattern::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { - return getOperand(); -} - -OpFoldResult TestOpConstant::fold(ArrayRef operands) { - return getValue(); -} - -LogicalResult TestOpWithVariadicResultsAndFolder::fold( - ArrayRef operands, SmallVectorImpl &results) { - for (Value input : this->getOperands()) { - results.push_back(input); - } - return success(); -} - -OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { - assert(operands.size() == 1); - if (operands.front()) { - (*this)->setAttr("attr", operands.front()); - return getResult(); - } - return {}; -} - -OpFoldResult TestPassthroughFold::fold(ArrayRef 
operands) { - return getOperand(); -} - -LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( - MLIRContext *, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType() != operands[1].getType()) { - return emitOptionalError(location, "operand type mismatch ", - operands[0].getType(), " vs ", - operands[1].getType()); - } - inferredReturnTypes.assign({operands[0].getType()}); - return success(); -} - -LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnShapes) { - // Create return type consisting of the last element of the first operand. - auto operandType = operands.front().getType(); - auto sval = operandType.dyn_cast(); - if (!sval) { - return emitOptionalError(location, "only shaped type operands allowed"); - } - int64_t dim = - sval.hasRank() ? 
sval.getShape().front() : ShapedType::kDynamicSize; - auto type = IntegerType::get(context, 17); - inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); - return success(); -} - -LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, ValueRange operands, - llvm::SmallVectorImpl &shapes) { - shapes = SmallVector{ - builder.createOrFold(getLoc(), operands.front(), 0)}; - return success(); -} - -LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, ValueRange operands, - llvm::SmallVectorImpl &shapes) { - Location loc = getLoc(); - shapes.reserve(operands.size()); - for (Value operand : llvm::reverse(operands)) { - auto rank = operand.getType().cast().getRank(); - auto currShape = llvm::to_vector<4>( - llvm::map_range(llvm::seq(0, rank), [&](int64_t dim) -> Value { - return builder.createOrFold(loc, operand, dim); - })); - shapes.push_back(builder.create( - getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), - currShape)); - } - return success(); -} - -LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( - OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { - Location loc = getLoc(); - shapes.reserve(getNumOperands()); - for (Value operand : llvm::reverse(getOperands())) { - auto currShape = llvm::to_vector<4>(llvm::map_range( - llvm::seq( - 0, operand.getType().cast().getRank()), - [&](int64_t dim) -> Value { - return builder.createOrFold(loc, operand, dim); - })); - shapes.emplace_back(std::move(currShape)); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Test SideEffect interfaces -//===----------------------------------------------------------------------===// - -namespace { -/// A test resource for side effects. 
-struct TestResource : public SideEffects::Resource::Base { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) - - StringRef getName() final { return ""; } -}; -} // namespace - -static void testSideEffectOpGetEffect( - Operation *op, - SmallVectorImpl> - &effects) { - auto effectsAttr = op->getAttrOfType("effect_parameter"); - if (!effectsAttr) - return; - - effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); -} - -void SideEffectOp::getEffects( - SmallVectorImpl &effects) { - // Check for an effects attribute on the op instance. - ArrayAttr effectsAttr = (*this)->getAttrOfType("effects"); - if (!effectsAttr) - return; - - // If there is one, it is an array of dictionary attributes that hold - // information on the effects of this operation. - for (Attribute element : effectsAttr) { - DictionaryAttr effectElement = element.cast(); - - // Get the specific memory effect. - MemoryEffects::Effect *effect = - StringSwitch( - effectElement.get("effect").cast().getValue()) - .Case("allocate", MemoryEffects::Allocate::get()) - .Case("free", MemoryEffects::Free::get()) - .Case("read", MemoryEffects::Read::get()) - .Case("write", MemoryEffects::Write::get()); - - // Check for a non-default resource to use. - SideEffects::Resource *resource = SideEffects::DefaultResource::get(); - if (effectElement.get("test_resource")) - resource = TestResource::get(); - - // Check for a result to affect. 
- if (effectElement.get("on_result")) - effects.emplace_back(effect, getResult(), resource); - else if (Attribute ref = effectElement.get("on_reference")) - effects.emplace_back(effect, ref.cast(), resource); - else - effects.emplace_back(effect, resource); - } -} - -void SideEffectOp::getEffects( - SmallVectorImpl &effects) { - testSideEffectOpGetEffect(getOperation(), effects); -} - -//===----------------------------------------------------------------------===// -// StringAttrPrettyNameOp -//===----------------------------------------------------------------------===// - -// This op has fancy handling of its SSA result name. -ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, - OperationState &result) { - // Add the result types. - for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) - result.addTypes(parser.getBuilder().getIntegerType(32)); - - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) - return failure(); - - // If the attribute dictionary contains no 'names' attribute, infer it from - // the SSA name (if specified). - bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { - return attr.getName() == "names"; - }); - - // If there was no name specified, check to see if there was a useful name - // specified in the asm file. 
- if (hadNames || parser.getNumResults() == 0) - return success(); - - SmallVector names; - auto *context = result.getContext(); - - for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { - auto resultName = parser.getResultName(i); - StringRef nameStr; - if (!resultName.first.empty() && !isdigit(resultName.first[0])) - nameStr = resultName.first; - - names.push_back(nameStr); - } - - auto namesAttr = parser.getBuilder().getStrArrayAttr(names); - result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); - return success(); -} - -void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { - // Note that we only need to print the "name" attribute if the asmprinter - // result name disagrees with it. This can happen in strange cases, e.g. - // when there are conflicts. - bool namesDisagree = getNames().size() != getNumResults(); - - SmallString<32> resultNameStr; - for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { - resultNameStr.clear(); - llvm::raw_svector_ostream tmpStream(resultNameStr); - p.printOperand(getResult(i), tmpStream); - - auto expectedName = getNames()[i].dyn_cast(); - if (!expectedName || - tmpStream.str().drop_front() != expectedName.getValue()) { - namesDisagree = true; - } - } - - if (namesDisagree) - p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); - else - p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); -} - -// We set the SSA name in the asm syntax to the contents of the name -// attribute. 
-void StringAttrPrettyNameOp::getAsmResultNames( - function_ref setNameFn) { - - auto value = getNames(); - for (size_t i = 0, e = value.size(); i != e; ++i) - if (auto str = value[i].dyn_cast()) - if (!str.getValue().empty()) - setNameFn(getResult(i), str.getValue()); -} - -void CustomResultsNameOp::getAsmResultNames( - function_ref setNameFn) { - ArrayAttr value = getNames(); - for (size_t i = 0, e = value.size(); i != e; ++i) - if (auto str = value[i].dyn_cast()) - if (!str.getValue().empty()) - setNameFn(getResult(i), str.getValue()); -} - -//===----------------------------------------------------------------------===// -// ResultTypeWithTraitOp -//===----------------------------------------------------------------------===// - -LogicalResult ResultTypeWithTraitOp::verify() { - if ((*this)->getResultTypes()[0].hasTrait()) - return success(); - return emitError("result type should have trait 'TestTypeTrait'"); -} - -//===----------------------------------------------------------------------===// -// AttrWithTraitOp -//===----------------------------------------------------------------------===// - -LogicalResult AttrWithTraitOp::verify() { - if (getAttr().hasTrait()) - return success(); - return emitError("'attr' attribute should have trait 'TestAttrTrait'"); -} - -//===----------------------------------------------------------------------===// -// RegionIfOp -//===----------------------------------------------------------------------===// - -void RegionIfOp::print(OpAsmPrinter &p) { - p << " "; - p.printOperands(getOperands()); - p << ": " << getOperandTypes(); - p.printArrowTypeList(getResultTypes()); - p << " then "; - p.printRegion(getThenRegion(), - /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); - p << " else "; - p.printRegion(getElseRegion(), - /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); - p << " join "; - p.printRegion(getJoinRegion(), - /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); -} - 
-ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector operandInfos; - SmallVector operandTypes; - - result.regions.reserve(3); - Region *thenRegion = result.addRegion(); - Region *elseRegion = result.addRegion(); - Region *joinRegion = result.addRegion(); - - // Parse operand, type and arrow type lists. - if (parser.parseOperandList(operandInfos) || - parser.parseColonTypeList(operandTypes) || - parser.parseArrowTypeList(result.types)) - return failure(); - - // Parse all attached regions. - if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || - parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || - parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) - return failure(); - - return parser.resolveOperands(operandInfos, operandTypes, - parser.getCurrentLocation(), result.operands); -} - -OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { - assert(index < 2 && "invalid region index"); - return getOperands(); -} - -void RegionIfOp::getSuccessorRegions( - Optional index, ArrayRef operands, - SmallVectorImpl ®ions) { - // We always branch to the join region. - if (index.hasValue()) { - if (index.getValue() < 2) - regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); - else - regions.push_back(RegionSuccessor(getResults())); - return; - } - - // The then and else regions are the entry regions of this op. - regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); - regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); -} - -void RegionIfOp::getRegionInvocationBounds( - ArrayRef operands, - SmallVectorImpl &invocationBounds) { - // Each region is invoked at most once. 
- invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); -} - -//===----------------------------------------------------------------------===// -// AnyCondOp -//===----------------------------------------------------------------------===// - -void AnyCondOp::getSuccessorRegions(Optional index, - ArrayRef operands, - SmallVectorImpl ®ions) { - // The parent op branches into the only region, and the region branches back - // to the parent op. - if (index) - regions.emplace_back(&getRegion()); - else - regions.emplace_back(getResults()); -} - -void AnyCondOp::getRegionInvocationBounds( - ArrayRef operands, - SmallVectorImpl &invocationBounds) { - invocationBounds.emplace_back(1, 1); -} - -//===----------------------------------------------------------------------===// -// SingleNoTerminatorCustomAsmOp -//===----------------------------------------------------------------------===// - -ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, - OperationState &state) { - Region *body = state.addRegion(); - if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) - return failure(); - return success(); -} - -void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { - printer.printRegion( - getRegion(), /*printEntryBlockArgs=*/false, - // This op has a single block without terminators. But explicitly mark - // as not printing block terminators for testing. 
- /*printBlockTerminators=*/false); -} - -//===----------------------------------------------------------------------===// -// TestVerifiersOp -//===----------------------------------------------------------------------===// - -LogicalResult TestVerifiersOp::verify() { - if (!getRegion().hasOneBlock()) - return emitOpError("`hasOneBlock` trait hasn't been verified"); - - Operation *definingOp = getInput().getDefiningOp(); - if (definingOp && failed(mlir::verify(definingOp))) - return emitOpError("operand hasn't been verified"); - - emitRemark("success run of verifier"); - - return success(); -} - -LogicalResult TestVerifiersOp::verifyRegions() { - if (!getRegion().hasOneBlock()) - return emitOpError("`hasOneBlock` trait hasn't been verified"); - - for (Block &block : getRegion()) - for (Operation &op : block) - if (failed(mlir::verify(&op))) - return emitOpError("nested op hasn't been verified"); - - emitRemark("success run of region verifier"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// Test InferIntRangeInterface -//===----------------------------------------------------------------------===// - -void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); -} - -ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - // Parse the input argument - OpAsmParser::Argument argInfo; - argInfo.type = parser.getBuilder().getIndexType(); - if (failed(parser.parseArgument(argInfo))) - return failure(); - - // Parse the body region, and reuse the operand info as the argument info. 
- Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); -} - -void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { - p.printOptionalAttrDict((*this)->getAttrs()); - p << ' '; - p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, - /*omitType=*/true); - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -void TestWithBoundsRegionOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRanges) { - Value arg = getRegion().getArgument(0); - setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); -} - -void TestIncrementOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRanges) { - const ConstantIntRanges &range = argRanges[0]; - APInt one(range.umin().getBitWidth(), 1); - setResultRanges(getResult(), - {range.umin().uadd_sat(one), range.umax().uadd_sat(one), - range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); -} - -void TestReflectBoundsOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRanges) { - const ConstantIntRanges &range = argRanges[0]; - MLIRContext *ctx = getContext(); - Builder b(ctx); - setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); - setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); - setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); - setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); - setResultRanges(getResult(), range); -} - -#include "TestOpEnums.cpp.inc" -#include "TestOpInterfaces.cpp.inc" -#include "TestOpStructs.cpp.inc" -#include "TestTypeInterfaces.cpp.inc" - -#define GET_OP_CLASSES -#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -31,6 +31,12 @@ void registerAttributes(); void registerTypes(); + // Shard op registration hooks. 
+ void registerOps0(); + void registerOps1(); + void registerOps2(); + void registerOps3(); + // Provides a custom printing/parsing for some operations. ::llvm::Optional getParseOperationHook(::llvm::StringRef opName) const override; diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h @@ -0,0 +1,149 @@ +//===- TestFormatUtils.h - MLIR Test Dialect Assembly Format Utilities ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTFORMATUTILS_H +#define MLIR_TESTFORMATUTILS_H + +#include "mlir/IR/OpImplementation.h" + +namespace test { + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperands +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveOperands( + mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand, + llvm::Optional &optOperand, + llvm::SmallVectorImpl &varOperands); + +void printCustomDirectiveOperands(mlir::OpAsmPrinter &printer, + mlir::Operation *, mlir::Value operand, + mlir::Value optOperand, + mlir::OperandRange varOperands); + +//===----------------------------------------------------------------------===// +// CustomDirectiveResults +//===----------------------------------------------------------------------===// + +mlir::ParseResult +parseCustomDirectiveResults(mlir::OpAsmParser &parser, mlir::Type &operandType, + mlir::Type &optOperandType, + llvm::SmallVectorImpl &varOperandTypes); + +void printCustomDirectiveResults(mlir::OpAsmPrinter &printer, mlir::Operation *, + mlir::Type operandType, + mlir::Type 
optOperandType, + mlir::TypeRange varOperandTypes); + +//===----------------------------------------------------------------------===// +// CustomDirectiveWithTypeRefs +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveWithTypeRefs( + mlir::OpAsmParser &parser, mlir::Type operandType, + mlir::Type optOperandType, + const llvm::SmallVectorImpl &varOperandTypes); + +void printCustomDirectiveWithTypeRefs(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + mlir::Type operandType, + mlir::Type optOperandType, + mlir::TypeRange varOperandTypes); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperandsAndTypes +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveOperandsAndTypes( + mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand, + llvm::Optional &optOperand, + llvm::SmallVectorImpl &varOperands, + mlir::Type &operandType, mlir::Type &optOperandType, + llvm::SmallVectorImpl &varOperandTypes); + +void printCustomDirectiveOperandsAndTypes( + mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::Value operand, + mlir::Value optOperand, mlir::OperandRange varOperands, + mlir::Type operandType, mlir::Type optOperandType, + mlir::TypeRange varOperandTypes); + +//===----------------------------------------------------------------------===// +// CustomDirectiveRegions +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveRegions( + mlir::OpAsmParser &parser, mlir::Region ®ion, + llvm::SmallVectorImpl> &varRegions); + +void printCustomDirectiveRegions( + mlir::OpAsmPrinter &printer, mlir::Operation *, mlir::Region ®ion, + llvm::MutableArrayRef varRegions); + +//===----------------------------------------------------------------------===// +// CustomDirectiveSuccessors 
+//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveSuccessors( + mlir::OpAsmParser &parser, mlir::Block *&successor, + llvm::SmallVectorImpl &varSuccessors); + +void printCustomDirectiveSuccessors(mlir::OpAsmPrinter &printer, + mlir::Operation *, mlir::Block *successor, + mlir::SuccessorRange varSuccessors); + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttributes +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveAttributes(mlir::OpAsmParser &parser, + mlir::IntegerAttr &attr, + mlir::IntegerAttr &optAttr); + +void printCustomDirectiveAttributes(mlir::OpAsmPrinter &printer, + mlir::Operation *, + mlir::Attribute attribute, + mlir::Attribute optAttribute); + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttrDict +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveAttrDict(mlir::OpAsmParser &parser, + mlir::NamedAttrList &attrs); + +void printCustomDirectiveAttrDict(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + mlir::DictionaryAttr attrs); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperandRef +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveOptionalOperandRef( + mlir::OpAsmParser &parser, + llvm::Optional &optOperand); + +void printCustomDirectiveOptionalOperandRef(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + mlir::Value optOperand); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperand +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomOptionalOperand( + 
mlir::OpAsmParser &parser, + llvm::Optional &optOperand); + +void printCustomOptionalOperand(mlir::OpAsmPrinter &printer, mlir::Operation *, + mlir::Value optOperand); + +} // end namespace test + +#endif // MLIR_TESTFORMATUTILS_H diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp @@ -0,0 +1,263 @@ +//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestFormatUtils.h" + +using namespace mlir; +using namespace test; + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperands +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveOperands( + OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, + Optional &optOperand, + SmallVectorImpl &varOperands) { + if (parser.parseOperand(operand)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand)) + return failure(); + } + if (parser.parseArrow() || parser.parseLParen() || + parser.parseOperandList(varOperands) || parser.parseRParen()) + return failure(); + return success(); +} + +void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, + Value operand, Value optOperand, + OperandRange varOperands) { + printer << operand; + if (optOperand) + printer << ", " << optOperand; + printer << " -> (" << varOperands << ")"; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveResults 
+//===----------------------------------------------------------------------===// + +ParseResult +test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, + Type &optOperandType, + SmallVectorImpl &varOperandTypes) { + if (parser.parseColon()) + return failure(); + + if (parser.parseType(operandType)) + return failure(); + if (succeeded(parser.parseOptionalComma())) + if (parser.parseType(optOperandType)) + return failure(); + if (parser.parseArrow() || parser.parseLParen() || + parser.parseTypeList(varOperandTypes) || parser.parseRParen()) + return failure(); + return success(); +} + +void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, + Type operandType, Type optOperandType, + TypeRange varOperandTypes) { + printer << " : " << operandType; + if (optOperandType) + printer << ", " << optOperandType; + printer << " -> (" << varOperandTypes << ")"; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveWithTypeRefs +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveWithTypeRefs( + OpAsmParser &parser, Type operandType, Type optOperandType, + const SmallVectorImpl &varOperandTypes) { + if (parser.parseKeyword("type_refs_capture")) + return failure(); + + Type operandType2, optOperandType2; + SmallVector varOperandTypes2; + if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, + varOperandTypes2)) + return failure(); + + if (operandType != operandType2 || optOperandType != optOperandType2 || + varOperandTypes != varOperandTypes2) + return failure(); + + return success(); +} + +void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, + Operation *op, Type operandType, + Type optOperandType, + TypeRange varOperandTypes) { + printer << " type_refs_capture "; + printCustomDirectiveResults(printer, op, operandType, optOperandType, + varOperandTypes); +} + 
+//===----------------------------------------------------------------------===// +// CustomDirectiveOperandsAndTypes +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveOperandsAndTypes( + OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, + Optional &optOperand, + SmallVectorImpl &varOperands, + Type &operandType, Type &optOperandType, + SmallVectorImpl &varOperandTypes) { + if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || + parseCustomDirectiveResults(parser, operandType, optOperandType, + varOperandTypes)) + return failure(); + return success(); +} + +void test::printCustomDirectiveOperandsAndTypes( + OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, + OperandRange varOperands, Type operandType, Type optOperandType, + TypeRange varOperandTypes) { + printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); + printCustomDirectiveResults(printer, op, operandType, optOperandType, + varOperandTypes); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveRegions +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveRegions( + OpAsmParser &parser, Region ®ion, + SmallVectorImpl> &varRegions) { + if (parser.parseRegion(region)) + return failure(); + if (failed(parser.parseOptionalComma())) + return success(); + std::unique_ptr varRegion = std::make_unique(); + if (parser.parseRegion(*varRegion)) + return failure(); + varRegions.emplace_back(std::move(varRegion)); + return success(); +} + +void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, + Region ®ion, + MutableArrayRef varRegions) { + printer.printRegion(region); + if (!varRegions.empty()) { + printer << ", "; + for (Region ®ion : varRegions) + printer.printRegion(region); + } +} + 
+//===----------------------------------------------------------------------===// +// CustomDirectiveSuccessors +//===----------------------------------------------------------------------===// + +ParseResult +test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, + SmallVectorImpl &varSuccessors) { + if (parser.parseSuccessor(successor)) + return failure(); + if (failed(parser.parseOptionalComma())) + return success(); + Block *varSuccessor; + if (parser.parseSuccessor(varSuccessor)) + return failure(); + varSuccessors.append(2, varSuccessor); + return success(); +} + +void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, + Block *successor, + SuccessorRange varSuccessors) { + printer << successor; + if (!varSuccessors.empty()) + printer << ", " << varSuccessors.front(); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttributes +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser, + IntegerAttr &attr, + IntegerAttr &optAttr) { + if (parser.parseAttribute(attr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(optAttr)) + return failure(); + } + return success(); +} + +void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, + Attribute attribute, + Attribute optAttribute) { + printer << attribute; + if (optAttribute) + printer << ", " << optAttribute; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttrDict +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser, + NamedAttrList &attrs) { + return parser.parseOptionalAttrDict(attrs); +} + +void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, + DictionaryAttr attrs) { + 
printer.printOptionalAttrDict(attrs.getValue()); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperandRef +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveOptionalOperandRef( + OpAsmParser &parser, Optional &optOperand) { + int64_t operandCount = 0; + if (parser.parseInteger(operandCount)) + return failure(); + bool expectedOptionalOperand = operandCount == 0; + return success(expectedOptionalOperand != optOperand.hasValue()); +} + +void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, + Operation *op, + Value optOperand) { + printer << (optOperand ? "1" : "0"); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperand +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomOptionalOperand( + OpAsmParser &parser, Optional &optOperand) { + if (succeeded(parser.parseOptionalLParen())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand) || parser.parseRParen()) + return failure(); + } + return success(); +} + +void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, + Value optOperand) { + if (optOperand) + printer << "(" << optOperand << ") "; +} diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.h b/mlir/test/lib/Dialect/Test/TestInterfaces.h --- a/mlir/test/lib/Dialect/Test/TestInterfaces.h +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.h @@ -34,4 +34,6 @@ } // namespace TestEffects } // namespace mlir +#include "TestOpInterfaces.h.inc" + #endif // MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp --- a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp @@ -6,3 +6,5 @@ const mlir::SideEffects::Effect 
*effect) { return isa(effect); } + +#include "TestOpInterfaces.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp copy from mlir/test/lib/Dialect/Test/TestDialect.cpp copy to mlir/test/lib/Dialect/Test/TestOpDefs.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1,4 +1,4 @@ -//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// +//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,401 +6,25 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" -#include "TestAttributes.h" -#include "TestInterfaces.h" -#include "TestTypes.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "TestOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/ExtensibleDialect.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Reducer/ReductionPatternInterface.h" -#include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSwitch.h" - -// Include this before the using namespace lines below to -// test that we don't have namespace dependencies. 
-#include "TestOpsDialect.cpp.inc" using namespace mlir; using namespace test; -void test::registerTestDialect(DialectRegistry ®istry) { - registry.insert(); -} - -//===----------------------------------------------------------------------===// -// TestDialect Interfaces -//===----------------------------------------------------------------------===// - -namespace { - -/// Testing the correctness of some traits. -static_assert( - llvm::is_detected::value, - "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); -static_assert(OpTrait::hasSingleBlockImplicitTerminator< - SingleBlockImplicitTerminatorOp>::value, - "hasSingleBlockImplicitTerminator does not match " - "SingleBlockImplicitTerminatorOp"); - -// Test support for interacting with the AsmPrinter. -struct TestOpAsmInterface : public OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; - - AliasResult getAlias(Attribute attr, raw_ostream &os) const final { - StringAttr strAttr = attr.dyn_cast(); - if (!strAttr) - return AliasResult::NoAlias; - - // Check the contents of the string attribute to see what the test alias - // should be named. 
- Optional aliasName = - StringSwitch>(strAttr.getValue()) - .Case("alias_test:dot_in_name", StringRef("test.alias")) - .Case("alias_test:trailing_digit", StringRef("test_alias0")) - .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) - .Case("alias_test:sanitize_conflict_a", - StringRef("test_alias_conflict0")) - .Case("alias_test:sanitize_conflict_b", - StringRef("test_alias_conflict0_")) - .Case("alias_test:tensor_encoding", StringRef("test_encoding")) - .Default(llvm::None); - if (!aliasName) - return AliasResult::NoAlias; - - os << *aliasName; - return AliasResult::FinalAlias; - } - - AliasResult getAlias(Type type, raw_ostream &os) const final { - if (auto tupleType = type.dyn_cast()) { - if (tupleType.size() > 0 && - llvm::all_of(tupleType.getTypes(), [](Type elemType) { - return elemType.isa(); - })) { - os << "test_tuple"; - return AliasResult::FinalAlias; - } - } - if (auto intType = type.dyn_cast()) { - if (intType.getSignedness() == - TestIntegerType::SignednessSemantics::Unsigned && - intType.getWidth() == 8) { - os << "test_ui8"; - return AliasResult::FinalAlias; - } - } - return AliasResult::NoAlias; - } -}; - -struct TestDialectFoldInterface : public DialectFoldInterface { - using DialectFoldInterface::DialectFoldInterface; - - /// Registered hook to check if the given region, which is attached to an - /// operation that is *not* isolated from above, should be used when - /// materializing constants. - bool shouldMaterializeInto(Region *region) const final { - // If this is a one region operation, then insert into it. - return isa(region->getParentOp()); - } -}; - -/// This class defines the interface for handling inlining with standard -/// operations. 
-struct TestInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - // Don't allow inlining calls that are marked `noinline`. - return !call->hasAttr("noinline"); - } - bool isLegalToInline(Region *, Region *, bool, - BlockAndValueMapping &) const final { - // Inlining into test dialect regions is legal. - return true; - } - bool isLegalToInline(Operation *, Region *, bool, - BlockAndValueMapping &) const final { - return true; - } - - bool shouldAnalyzeRecursively(Operation *op) const final { - // Analyze recursively if this is not a functional region operation, it - // froms a separate functional scope. - return !isa(op); - } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { - // Only handle "test.return" here. - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } - - /// Attempt to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. 
If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - // Only allow conversion for i16/i32 types. - if (!(resultType.isSignlessInteger(16) || - resultType.isSignlessInteger(32)) || - !(input.getType().isSignlessInteger(16) || - input.getType().isSignlessInteger(32))) - return nullptr; - return builder.create(conversionLoc, resultType, input); - } - - void processInlinedCallBlocks( - Operation *call, - iterator_range inlinedBlocks) const final { - if (!isa(call)) - return; - - // Set attributed on all ops in the inlined blocks. - for (Block &block : inlinedBlocks) { - block.walk([&](Operation *op) { - op->setAttr("inlined_conversion", UnitAttr::get(call->getContext())); - }); - } - } -}; - -struct TestReductionPatternInterface : public DialectReductionPatternInterface { -public: - TestReductionPatternInterface(Dialect *dialect) - : DialectReductionPatternInterface(dialect) {} - - void populateReductionPatterns(RewritePatternSet &patterns) const final { - populateTestReductionPatterns(patterns); - } -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Dynamic operations -//===----------------------------------------------------------------------===// - -std::unique_ptr getDynamicGenericOp(TestDialect *dialect) { - return DynamicOpDefinition::get( - "dynamic_generic", dialect, [](Operation *op) { return success(); }, - [](Operation *op) { return success(); }); -} - -std::unique_ptr -getDynamicOneOperandTwoResultsOp(TestDialect *dialect) { - return DynamicOpDefinition::get( - "dynamic_one_operand_two_results", dialect, - [](Operation *op) { - if (op->getNumOperands() != 1) { - op->emitOpError() - << "expected 1 operand, but had " << op->getNumOperands(); - return failure(); - } - if (op->getNumResults() != 2) { - op->emitOpError() - << "expected 2 results, but had 
" << op->getNumResults(); - return failure(); - } - return success(); - }, - [](Operation *op) { return success(); }); -} - -std::unique_ptr -getDynamicCustomParserPrinterOp(TestDialect *dialect) { - auto verifier = [](Operation *op) { - if (op->getNumOperands() == 0 && op->getNumResults() == 0) - return success(); - op->emitError() << "operation should have no operands and no results"; - return failure(); - }; - auto regionVerifier = [](Operation *op) { return success(); }; - - auto parser = [](OpAsmParser &parser, OperationState &state) { - return parser.parseKeyword("custom_keyword"); - }; - - auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) { - printer << op->getName() << " custom_keyword"; - }; - - return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect, - verifier, regionVerifier, parser, printer); -} - //===----------------------------------------------------------------------===// -// TestDialect +// FormatInferType2Op //===----------------------------------------------------------------------===// -static void testSideEffectOpGetEffect( - Operation *op, - SmallVectorImpl> &effects); - -// This is the implementation of a dialect fallback for `TestEffectOpInterface`. 
-struct TestOpEffectInterfaceFallback - : public TestEffectOpInterface::FallbackModel< - TestOpEffectInterfaceFallback> { - static bool classof(Operation *op) { - bool isSupportedOp = - op->getName().getStringRef() == "test.unregistered_side_effect_op"; - assert(isSupportedOp && "Unexpected dispatch"); - return isSupportedOp; - } - - void - getEffects(Operation *op, - SmallVectorImpl> - &effects) const { - testSideEffectOpGetEffect(op, effects); - } -}; - -void TestDialect::initialize() { - registerAttributes(); - registerTypes(); - addOperations< -#define GET_OP_LIST -#include "TestOps.cpp.inc" - >(); - registerDynamicOp(getDynamicGenericOp(this)); - registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); - registerDynamicOp(getDynamicCustomParserPrinterOp(this)); - - addInterfaces(); - allowUnknownOperations(); - - // Instantiate our fallback op interface that we'll use on specific - // unregistered op. - fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; -} -TestDialect::~TestDialect() { - delete static_cast( - fallbackEffectOpInterfaces); -} - -Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, - Type type, Location loc) { - return builder.create(loc, type, value); -} - -::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( - ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, - ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, - ::mlir::RegionRange regions, - ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { - inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); - return ::mlir::success(); -} - -void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, - OperationName opName) { - if (opName.getIdentifier() == "test.unregistered_side_effect_op" && - typeID == TypeID::get()) - return fallbackEffectOpInterfaces; - return nullptr; -} - -LogicalResult TestDialect::verifyOperationAttribute(Operation *op, - NamedAttribute namedAttr) { - if 
(namedAttr.getName() == "test.invalid_attr") - return op->emitError() << "invalid to use 'test.invalid_attr'"; - return success(); -} - -LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, - unsigned regionIndex, - unsigned argIndex, - NamedAttribute namedAttr) { - if (namedAttr.getName() == "test.invalid_attr") - return op->emitError() << "invalid to use 'test.invalid_attr'"; - return success(); -} - -LogicalResult -TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, - unsigned resultIndex, - NamedAttribute namedAttr) { - if (namedAttr.getName() == "test.invalid_attr") - return op->emitError() << "invalid to use 'test.invalid_attr'"; +LogicalResult FormatInferType2Op::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.assign({IntegerType::get(context, 16)}); return success(); } -Optional -TestDialect::getParseOperationHook(StringRef opName) const { - if (opName == "test.dialect_custom_printer") { - return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { - return parser.parseKeyword("custom_format"); - }}; - } - if (opName == "test.dialect_custom_format_fallback") { - return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { - return parser.parseKeyword("custom_format_fallback"); - }}; - } - if (opName == "test.dialect_custom_printer.with.dot") { - return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { - return ParseResult::success(); - }}; - } - return None; -} - -llvm::unique_function -TestDialect::getOperationPrinter(Operation *op) const { - StringRef opName = op->getName().getStringRef(); - if (opName == "test.dialect_custom_printer") { - return [](Operation *op, OpAsmPrinter &printer) { - printer.getStream() << " custom_format"; - }; - } - if (opName == "test.dialect_custom_format_fallback") { - return [](Operation *op, OpAsmPrinter &printer) 
{ - printer.getStream() << " custom_format_fallback"; - }; - } - return {}; -} - //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// @@ -422,7 +46,7 @@ } //===----------------------------------------------------------------------===// -// TestProducingBranchOp +// TestInternalBranchOp //===----------------------------------------------------------------------===// SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { @@ -432,23 +56,6 @@ return SuccessorOperands(1, getErrorOperandsMutable()); } -//===----------------------------------------------------------------------===// -// TestDialectCanonicalizerOp -//===----------------------------------------------------------------------===// - -static LogicalResult -dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, - PatternRewriter &rewriter) { - rewriter.replaceOpWithNewOp( - op, rewriter.getI32IntegerAttr(42)); - return success(); -} - -void TestDialect::getCanonicalizationPatterns( - RewritePatternSet &results) const { - results.add(&dialectCanonicalizationPattern); -} - //===----------------------------------------------------------------------===// // TestCallOp //===----------------------------------------------------------------------===// @@ -465,7 +72,7 @@ } //===----------------------------------------------------------------------===// -// TestFoldToCallOp +// FoldToCallOp //===----------------------------------------------------------------------===// namespace { @@ -485,219 +92,8 @@ MLIRContext *context) { results.add(context); } - -//===----------------------------------------------------------------------===// -// Test Format* operations -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// Parsing - -static ParseResult 
parseCustomOptionalOperand( - OpAsmParser &parser, Optional &optOperand) { - if (succeeded(parser.parseOptionalLParen())) { - optOperand.emplace(); - if (parser.parseOperand(*optOperand) || parser.parseRParen()) - return failure(); - } - return success(); -} - -static ParseResult parseCustomDirectiveOperands( - OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, - Optional &optOperand, - SmallVectorImpl &varOperands) { - if (parser.parseOperand(operand)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - optOperand.emplace(); - if (parser.parseOperand(*optOperand)) - return failure(); - } - if (parser.parseArrow() || parser.parseLParen() || - parser.parseOperandList(varOperands) || parser.parseRParen()) - return failure(); - return success(); -} -static ParseResult -parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, - Type &optOperandType, - SmallVectorImpl &varOperandTypes) { - if (parser.parseColon()) - return failure(); - - if (parser.parseType(operandType)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseType(optOperandType)) - return failure(); - } - if (parser.parseArrow() || parser.parseLParen() || - parser.parseTypeList(varOperandTypes) || parser.parseRParen()) - return failure(); - return success(); -} -static ParseResult -parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, - Type optOperandType, - const SmallVectorImpl &varOperandTypes) { - if (parser.parseKeyword("type_refs_capture")) - return failure(); - - Type operandType2, optOperandType2; - SmallVector varOperandTypes2; - if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, - varOperandTypes2)) - return failure(); - - if (operandType != operandType2 || optOperandType != optOperandType2 || - varOperandTypes != varOperandTypes2) - return failure(); - - return success(); -} -static ParseResult parseCustomDirectiveOperandsAndTypes( - OpAsmParser &parser, 
OpAsmParser::UnresolvedOperand &operand, - Optional &optOperand, - SmallVectorImpl &varOperands, - Type &operandType, Type &optOperandType, - SmallVectorImpl &varOperandTypes) { - if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || - parseCustomDirectiveResults(parser, operandType, optOperandType, - varOperandTypes)) - return failure(); - return success(); -} -static ParseResult parseCustomDirectiveRegions( - OpAsmParser &parser, Region ®ion, - SmallVectorImpl> &varRegions) { - if (parser.parseRegion(region)) - return failure(); - if (failed(parser.parseOptionalComma())) - return success(); - std::unique_ptr varRegion = std::make_unique(); - if (parser.parseRegion(*varRegion)) - return failure(); - varRegions.emplace_back(std::move(varRegion)); - return success(); -} -static ParseResult -parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, - SmallVectorImpl &varSuccessors) { - if (parser.parseSuccessor(successor)) - return failure(); - if (failed(parser.parseOptionalComma())) - return success(); - Block *varSuccessor; - if (parser.parseSuccessor(varSuccessor)) - return failure(); - varSuccessors.append(2, varSuccessor); - return success(); -} -static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, - IntegerAttr &attr, - IntegerAttr &optAttr) { - if (parser.parseAttribute(attr)) - return failure(); - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseAttribute(optAttr)) - return failure(); - } - return success(); -} - -static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, - NamedAttrList &attrs) { - return parser.parseOptionalAttrDict(attrs); -} -static ParseResult parseCustomDirectiveOptionalOperandRef( - OpAsmParser &parser, Optional &optOperand) { - int64_t operandCount = 0; - if (parser.parseInteger(operandCount)) - return failure(); - bool expectedOptionalOperand = operandCount == 0; - return success(expectedOptionalOperand != optOperand.hasValue()); -} - 
//===----------------------------------------------------------------------===// -// Printing - -static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, - Value optOperand) { - if (optOperand) - printer << "(" << optOperand << ") "; -} - -static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, - Value operand, Value optOperand, - OperandRange varOperands) { - printer << operand; - if (optOperand) - printer << ", " << optOperand; - printer << " -> (" << varOperands << ")"; -} -static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, - Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printer << " : " << operandType; - if (optOperandType) - printer << ", " << optOperandType; - printer << " -> (" << varOperandTypes << ")"; -} -static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, - Operation *op, Type operandType, - Type optOperandType, - TypeRange varOperandTypes) { - printer << " type_refs_capture "; - printCustomDirectiveResults(printer, op, operandType, optOperandType, - varOperandTypes); -} -static void printCustomDirectiveOperandsAndTypes( - OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, - OperandRange varOperands, Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); - printCustomDirectiveResults(printer, op, operandType, optOperandType, - varOperandTypes); -} -static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, - Region ®ion, - MutableArrayRef varRegions) { - printer.printRegion(region); - if (!varRegions.empty()) { - printer << ", "; - for (Region ®ion : varRegions) - printer.printRegion(region); - } -} -static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, - Block *successor, - SuccessorRange varSuccessors) { - printer << successor; - if (!varSuccessors.empty()) - printer << ", " << 
varSuccessors.front(); -} -static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, - Attribute attribute, - Attribute optAttribute) { - printer << attribute; - if (optAttribute) - printer << ", " << optAttribute; -} - -static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs) { - printer.printOptionalAttrDict(attrs.getValue()); -} - -static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, - Operation *op, - Value optOperand) { - printer << (optOperand ? "1" : "0"); -} - -//===----------------------------------------------------------------------===// -// Test IsolatedRegionOp - parse passthrough region arguments. +// IsolatedRegionOp - test parsing passthrough operands //===----------------------------------------------------------------------===// ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, @@ -723,7 +119,7 @@ } //===----------------------------------------------------------------------===// -// Test SSACFGRegionOp +// SSACFGRegionOp //===----------------------------------------------------------------------===// RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { @@ -731,7 +127,7 @@ } //===----------------------------------------------------------------------===// -// Test GraphRegionOp +// GraphRegionOp //===----------------------------------------------------------------------===// RegionKind GraphRegionOp::getRegionKind(unsigned index) { @@ -739,7 +135,7 @@ } //===----------------------------------------------------------------------===// -// Test AffineScopeOp +// AffineScopeOp //===----------------------------------------------------------------------===// ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { @@ -754,7 +150,7 @@ } //===----------------------------------------------------------------------===// -// Test parser. 
+// ParseIntegerLiteralOp //===----------------------------------------------------------------------===// ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, @@ -776,6 +172,10 @@ p << " : " << numResults; } +//===----------------------------------------------------------------------===// +// ParseWrappedKeywordOp +//===----------------------------------------------------------------------===// + ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, OperationState &result) { StringRef keyword; @@ -788,7 +188,8 @@ void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } //===----------------------------------------------------------------------===// -// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. +// WrappingRegionOp - wrapping op exercising `parseGenericOperation()`. +//===----------------------------------------------------------------------===// ParseResult WrappingRegionOp::parse(OpAsmParser &parser, OperationState &result) { @@ -827,14 +228,13 @@ } //===----------------------------------------------------------------------===// -// Test PrettyPrintedRegionOp - exercising the following parser APIs +// PrettyPrintedRegionOp - exercising the following parser APIs // parseGenericOperationAfterOpName // parseCustomOperationName //===----------------------------------------------------------------------===// ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, OperationState &result) { - SMLoc loc = parser.getCurrentLocation(); Location currLocation = parser.getEncodedSourceLoc(loc); @@ -924,7 +324,7 @@ } //===----------------------------------------------------------------------===// -// Test PolyForOp - parse list of region arguments. +// PolyForOp - parse list of region arguments. 
//===----------------------------------------------------------------------===// ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { @@ -956,7 +356,7 @@ } //===----------------------------------------------------------------------===// -// Test removing op with inner ops. +// TestRemoveOpWithInnerOps //===----------------------------------------------------------------------===// namespace { @@ -974,19 +374,35 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// TestOpWithRegionPattern +//===----------------------------------------------------------------------===// + void TestOpWithRegionPattern::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add(context); } +//===----------------------------------------------------------------------===// +// TestOpWithRegionFold +//===----------------------------------------------------------------------===// + OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { return getOperand(); } +//===----------------------------------------------------------------------===// +// TestOpConstant +//===----------------------------------------------------------------------===// + OpFoldResult TestOpConstant::fold(ArrayRef operands) { return getValue(); } +//===----------------------------------------------------------------------===// +// TestOpWithVariadicResultsAndFolder +//===----------------------------------------------------------------------===// + LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { for (Value input : this->getOperands()) { @@ -995,6 +411,10 @@ return success(); } +//===----------------------------------------------------------------------===// +// TestOpInPlaceFold +//===----------------------------------------------------------------------===// + OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { assert(operands.size() == 1); if 
(operands.front()) { @@ -1004,10 +424,18 @@ return {}; } +//===----------------------------------------------------------------------===// +// TestPassthroughFold +//===----------------------------------------------------------------------===// + OpFoldResult TestPassthroughFold::fold(ArrayRef operands) { return getOperand(); } +//===----------------------------------------------------------------------===// +// OpWithInferTypeInterfaceOp +//===----------------------------------------------------------------------===// + LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1021,6 +449,10 @@ return success(); } +//===----------------------------------------------------------------------===// +// OpWithShapedTypeInferTypeInterfaceOp +//===----------------------------------------------------------------------===// + LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -1046,6 +478,10 @@ return success(); } +//===----------------------------------------------------------------------===// +// OpWithResultShapeInterfaceOp +//===----------------------------------------------------------------------===// + LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( OpBuilder &builder, ValueRange operands, llvm::SmallVectorImpl &shapes) { @@ -1064,6 +500,10 @@ return success(); } +//===----------------------------------------------------------------------===// +// OpWithResultShapePerDimInterfaceOp +//===----------------------------------------------------------------------===// + LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { Location loc = getLoc(); @@ -1081,29 +521,9 @@ } 
//===----------------------------------------------------------------------===// -// Test SideEffect interfaces +// SideEffectOp //===----------------------------------------------------------------------===// -namespace { -/// A test resource for side effects. -struct TestResource : public SideEffects::Resource::Base { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) - - StringRef getName() final { return ""; } -}; -} // namespace - -static void testSideEffectOpGetEffect( - Operation *op, - SmallVectorImpl> - &effects) { - auto effectsAttr = op->getAttrOfType("effect_parameter"); - if (!effectsAttr) - return; - - effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); -} - void SideEffectOp::getEffects( SmallVectorImpl &effects) { // Check for an effects attribute on the op instance. @@ -1142,7 +562,11 @@ void SideEffectOp::getEffects( SmallVectorImpl &effects) { - testSideEffectOpGetEffect(getOperation(), effects); + auto effectsAttr = (*this)->getAttrOfType("effect_parameter"); + if (!effectsAttr) + return; + + effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); } //===----------------------------------------------------------------------===// @@ -1216,7 +640,6 @@ // attribute. 
void StringAttrPrettyNameOp::getAsmResultNames( function_ref setNameFn) { - auto value = getNames(); for (size_t i = 0, e = value.size(); i != e; ++i) if (auto str = value[i].dyn_cast()) @@ -1224,6 +647,10 @@ setNameFn(getResult(i), str.getValue()); } +//===----------------------------------------------------------------------===// +// CustomResultsNameOp +//===----------------------------------------------------------------------===// + void CustomResultsNameOp::getAsmResultNames( function_ref setNameFn) { ArrayAttr value = getNames(); @@ -1351,13 +778,27 @@ invocationBounds.emplace_back(1, 1); } +//===----------------------------------------------------------------------===// +// SingleBlockImplicitTerminatorOp +//===----------------------------------------------------------------------===// + +/// Testing the correctness of some traits. +static_assert( + llvm::is_detected::value, + "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); +static_assert(OpTrait::hasSingleBlockImplicitTerminator< + SingleBlockImplicitTerminatorOp>::value, + "hasSingleBlockImplicitTerminator does not match " + "SingleBlockImplicitTerminatorOp"); + //===----------------------------------------------------------------------===// // SingleNoTerminatorCustomAsmOp //===----------------------------------------------------------------------===// ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, - OperationState &state) { - Region *body = state.addRegion(); + OperationState &result) { + Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) return failure(); return success(); @@ -1406,11 +847,17 @@ // Test InferIntRangeInterface //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TestWithBoundsOp + void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); } +//===----------------------------------------------------------------------===// +// TestWithBoundsRegionOp + ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) @@ -1442,6 +889,9 @@ setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); } +//===----------------------------------------------------------------------===// +// TestIncrementOp + void TestIncrementOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) { const ConstantIntRanges &range = argRanges[0]; @@ -1451,6 +901,9 @@ range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); } +//===----------------------------------------------------------------------===// +// TestReflectBoundsOp + void TestReflectBoundsOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRanges) { const ConstantIntRanges &range = argRanges[0]; @@ -1462,11 +915,3 @@ setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); setResultRanges(getResult(), range); } - -#include "TestOpEnums.cpp.inc" -#include "TestOpInterfaces.cpp.inc" -#include "TestOpStructs.cpp.inc" -#include "TestTypeInterfaces.cpp.inc" - -#define GET_OP_CLASSES -#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestOps.h copy from mlir/test/lib/Dialect/Test/TestDialect.h copy to mlir/test/lib/Dialect/Test/TestOps.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -1,34 +1,24 @@ -//===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===// +//===- TestOps.h - MLIR Test Dialect Operations ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This file defines a fake 'test' dialect that can be used for testing things -// that do not have a respective counterpart in the main source directories. -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_TESTDIALECT_H -#define MLIR_TESTDIALECT_H +#ifndef MLIR_TESTOPS_H +#define MLIR_TESTOPS_H #include "TestAttributes.h" #include "TestInterfaces.h" -#include "mlir/Dialect/DLTI/DLTI.h" +#include "TestTypes.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Traits.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/ExtensibleDialect.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" -#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" @@ -37,23 +27,15 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/ViewLikeInterface.h" -namespace mlir { -class DLTIDialect; -class RewritePatternSet; -} // namespace mlir - -#include "TestOpInterfaces.h.inc" -#include "TestOpStructs.h.inc" -#include "TestOpsDialect.h.inc" +namespace test { +/// A test resource for side effects. 
+struct TestResource : public mlir::SideEffects::Resource::Base { + llvm::StringRef getName() final { return ""; } +}; +} // namespace test #define GET_OP_CLASSES #include "TestOps.h.inc" -namespace test { -void registerTestDialect(::mlir::DialectRegistry &registry); -void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns); -} // namespace test - -#endif // MLIR_TESTDIALECT_H +#endif // MLIR_TESTOPS_H diff --git a/mlir/test/lib/Dialect/Test/TestOps0.cpp b/mlir/test/lib/Dialect/Test/TestOps0.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOps0.cpp @@ -0,0 +1,24 @@ +//===- TestOps0.cpp - MLIR Test Dialect Operations ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestFormatUtils.h" +#include "TestOps.h" + +using namespace mlir; +using namespace test; + +void test::TestDialect::registerOps1() { + addOperations< +#define GET_OP_LIST_1 +#include "TestOps.cpp.inc" + >(); +} + +#define GET_OP_CLASSES_1 +#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps2.cpp b/mlir/test/lib/Dialect/Test/TestOps2.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOps2.cpp @@ -0,0 +1,23 @@ +//===- TestOps2.cpp - MLIR Test Dialect Operations ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestOps.h" + +using namespace mlir; +using namespace test; + +void test::TestDialect::registerOps2() { + addOperations< +#define GET_OP_LIST_2 +#include "TestOps.cpp.inc" + >(); +} + +#define GET_OP_CLASSES_2 +#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps3.cpp b/mlir/test/lib/Dialect/Test/TestOps3.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOps3.cpp @@ -0,0 +1,23 @@ +//===- TestOps3.cpp - MLIR Test Dialect Operations ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestOps.h" + +using namespace mlir; +using namespace test; + +void test::TestDialect::registerOps3() { + addOperations< +#define GET_OP_LIST_3 +#include "TestOps.cpp.inc" + >(); +} + +#define GET_OP_CLASSES_3 +#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" -#include "TestTypes.h" +#include "TestOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -30,11 +30,11 @@ /// FieldInfo represents a field in the StructType data type. It is used as a /// parameter in TestTypeDefs.td. 
struct FieldInfo { - ::llvm::StringRef name; - ::mlir::Type type; + llvm::StringRef name; + mlir::Type type; // Custom allocation called from generated constructor code - FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const { + FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const { return FieldInfo{alloc.copyInto(name), type}; } }; diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -140,6 +140,7 @@ // Tablegen Generated Definitions //===----------------------------------------------------------------------===// +#include "TestTypeInterfaces.cpp.inc" #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.cpp.inc" diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp --- a/mlir/test/lib/IR/TestClone.cpp +++ b/mlir/test/lib/IR/TestClone.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp --- 
a/mlir/test/lib/IR/TestTypes.cpp +++ b/mlir/test/lib/IR/TestTypes.cpp @@ -8,6 +8,7 @@ #include "TestTypes.h" #include "TestDialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp --- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp +++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -12,8 +12,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "TestOps.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -19,6 +19,7 @@ #include "../../test/lib/Dialect/Test/TestAttributes.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestTypes.h" #include 
"mlir/IR/OwningOpRef.h" diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp --- a/mlir/unittests/IR/PatternMatchTest.cpp +++ b/mlir/unittests/IR/PatternMatchTest.cpp @@ -10,6 +10,7 @@ #include "gtest/gtest.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" using namespace mlir; diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -77,8 +78,8 @@ verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs); // Test collective params build method. - op = - builder.create(loc, TypeRange{i32Ty}, ValueRange{*cstI32, *cstI32}); + op = builder.create(loc, TypeRange{i32Ty}, + ValueRange{*cstI32, *cstI32}); verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs); // Test build method with no result types, default value of attributes. diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -60,7 +60,10 @@ "lib/Dialect/Test/TestOps.h.inc", ), ( - ["-gen-op-defs"], + [ + "-gen-op-defs", + "-op-shard-count=4", + ], "lib/Dialect/Test/TestOps.cpp.inc", ), (