diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1056,20 +1056,6 @@ //===------------------------------------------------------------------===// // Other static interface methods. //===------------------------------------------------------------------===// - StaticInterfaceMethod< - /*desc=*/[{ - Create an operation of the current type with the given location, - operands, and attributes. - }], - /*retTy=*/"Operation *", - /*methodName=*/"create", - (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands, - "ArrayRef":$attributes), [{ - return builder.create( - loc, resultTypes, operands, attributes); - }] - >, InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location and operands. This @@ -1082,14 +1068,13 @@ (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, "ValueRange":$operands), [{ - BlockAndValueMapping map; - unsigned numRegions = $_op->getNumRegions(); - Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs()); - assert(res->getNumRegions() == numRegions && "inconsistent # regions"); - for (unsigned ridx = 0; ridx < numRegions; ++ridx) - $_op->getRegion(ridx).cloneInto( - &res->getRegion(ridx), map); - return res; + BlockAndValueMapping bvm; + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (Region &r : $_op->getRegions()) + r.cloneInto(state.addRegion(), bvm); + return b.createOperation(state); }] >, StaticInterfaceMethod< @@ -1098,7 +1083,7 @@ Returns a null function if this named op does not define a region builder. 
}], - /*retTy=*/"std::function", + /*retTy=*/"std::function", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -110,14 +110,13 @@ AnyStridedMemRef:$output, OptionalAttr:$inputPermutation, OptionalAttr:$outputPermutation); + let regions = (region AnyRegion:$region); - // TODO: this should go away once the usage of OptionalAttr triggers emission - // of builders with default arguments left unspecified. - let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output), - [{ - return build( - $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr()); - }]>]; + let builders = [ + OpBuilderDAG<(ins "Value":$input, "Value":$output, + CArg<"AffineMap", "AffineMap()">:$inputPermutation, + CArg<"AffineMap", "AffineMap()">:$outputPermutation, + CArg<"ArrayRef", "{}">:$attrs)>]; let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return getOperands().take_front(); } @@ -146,24 +145,31 @@ Value getSource() { return input();} Value getTarget() { return output(); } - static std::function getRegionBuilder() { - return nullptr; + static void regionBuilder(Block &block, ValueRange captures); + static std::function + getRegionBuilder() { + return &regionBuilder; } + static unsigned getNumRegionArgs() { return 2; } }]; let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ - `(` operands `)` attr-dict `:` type(operands) + `(` $input `,` $output `)` attr-dict `:` + type($input) `,` type($output) + custom($region, ref(type($input)), ref(type($input))) }]; let hasFolder = 1; let hasCanonicalizer = 1; + let skipDefaultBuilders = 1; } def FillOp : LinalgStructured_Op<"fill", []> { let arguments = (ins AnyShaped:$output, AnyTypeOf<[AnyFloat, 
AnySignlessInteger, AnyVector]>:$value); let results = (outs Optional:$result); + let regions = (region AnyRegion:$region); let extraClassDeclaration = structuredOpsDecls # [{ ValueRange inputs() { return {}; } ValueRange outputs() { return getOperands().take_front(); } @@ -183,13 +189,18 @@ extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static std::function getRegionBuilder() { - return nullptr; + static void regionBuilder(Block &block, ValueRange captures); + static std::function + getRegionBuilder() { + return &regionBuilder; } + static unsigned getNumRegionArgs() { return 1; } }]; let assemblyFormat = [{ - `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)? + `(` $output `,` $value `)` attr-dict `:` + type($output) `,` type($value) (`->` type($result)^)? + custom($region, ref(type($output)), ref($value)) }]; let builders = [ @@ -268,7 +279,8 @@ return padding().getValue().getValue({i, 1}); } - static std::function getRegionBuilder() { + static std::function getRegionBuilder() + { return nullptr; } }]; @@ -519,7 +531,7 @@ library_call()->str() : "op_has_no_registered_library_name"; } - static std::function getRegionBuilder() { + static std::function getRegionBuilder() { return nullptr; } }]; diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -154,7 +154,13 @@ if (in == op.input() && out == op.output()) return failure(); - rewriter.replaceOpWithNewOp(op, in, out); + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); + + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out})); return success(); } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp 
b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -27,8 +27,6 @@ ArrayRef outputs, TypeRange resultTensorTypes, function_ref regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { - OpBuilder &builder = edsc::ScopedContext::getBuilderRef(); - // Build maps SmallVector, 4> exprsList; exprsList.reserve(inputs.size() + outputs.size()); @@ -54,13 +52,10 @@ resultTensorTypes, inputValues, outputValues, - builder.getAffineMapArrayAttr(maps), - builder.getStrArrayAttr(iteratorStrTypes), - StringAttr() /*doc*/, - StringAttr() /*library_call*/, - ArrayAttr() /*sparse*/ - /* TODO: other attributes in op */ - ) + maps, + iteratorStrTypes, + ""/*doc*/, + ""/*library_call*/) .getOperation(); // clang-format on diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -33,32 +33,53 @@ using namespace mlir::linalg; /// Forward declarations. + +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted +/// to be ShapedType. +template +static void fillStructuredOpRegion( + OpBuilder &opBuilder, Region &region, TypeRange inputTypes, + TypeRange outputTypes, ValueRange captures = {}, + std::function errorHandler = [](unsigned, + unsigned) {}); + +/// Generic entry point to create both the region and the block of a LinalgOp. 
template -static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes); +static void +createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, + TypeRange inputTypes, TypeRange outputTypes, + ValueRange captures = {}); +/// Common parsing and printing used for both named structured ops created by +/// ods-gen and by manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes); +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op); +/// Specific parsing and printing for named structured ops created by ods-gen. template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, - TypeRange inputTypes, TypeRange outputTypes); + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef captures = {}); + static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes); template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result); - -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op); +static ParseResult +parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, + ArrayRef captures = {}); static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); @@ -83,14 +104,136 @@ return success(folded); } +//===----------------------------------------------------------------------===// +// CopyOp +//===----------------------------------------------------------------------===// +void CopyOp::regionBuilder(Block &block, ValueRange captures) { + using namespace edsc::intrinsics; + assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); + (linalg_yield(block.getArgument(0))); +} + +void 
CopyOp::build(OpBuilder &builder, OperationState &result, Value input, + Value output, AffineMap inputPermutation, + AffineMap outputPermutation, + ArrayRef namedAttrs) { + result.addOperands({input, output}); + result.addAttributes(namedAttrs); + if (inputPermutation) + result.addAttribute("inputPermutation", + AffineMapAttr::get(inputPermutation)); + if (outputPermutation) + result.addAttribute("outputPermutation", + AffineMapAttr::get(outputPermutation)); + result.addRegion(); + fillStructuredOpRegion(builder, *result.regions.front(), + TypeRange{input.getType()}, + TypeRange{output.getType()}); +} + +ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType, + Type outputType) { + OpBuilder opBuilder(parser.getBuilder().getContext()); + fillStructuredOpRegion(opBuilder, r, TypeRange{inputType}, + TypeRange{outputType}); + return success(); +} + +/// CopyOp region is elided when printing. +void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {} + +static LogicalResult verify(CopyOp op) { + auto outputViewType = op.getOutputShapedType(0); + auto inputViewType = op.getInputShapedType(0); + if (inputViewType.getElementType() != outputViewType.getElementType()) + return op.emitOpError("expects views of the same type"); + if (inputViewType.getRank() != outputViewType.getRank()) + return op.emitOpError("expects views of the same rank"); + auto rank = op.getNumParallelLoops(); + auto inputPermutationMap = op.inputPermutation(); + if (inputPermutationMap) { + if (inputPermutationMap->getNumInputs() != rank) + return op.emitOpError("expects optional input_permutation map of rank ") + << rank; + if (!inputPermutationMap->isPermutation()) + return op.emitOpError( + "expects optional input_permutation map to be a permutation"); + } + auto outputPermutationMap = op.outputPermutation(); + if (outputPermutationMap) { + if (outputPermutationMap->getNumInputs() != rank) + return op.emitOpError("expects optional output_permutation map of 
rank ") + << rank; + if (!outputPermutationMap->isPermutation()) + return op.emitOpError( + "expects optional output_permutation map to be a permutation"); + } + if (rank == 0 && inputPermutationMap) + return op.emitOpError("expected no input permutation when rank == 0"); + if (rank == 0 && outputPermutationMap) + return op.emitOpError("expected no output permutation when rank == 0"); + return success(); +} + +void CopyOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// +void FillOp::regionBuilder(Block &block, ValueRange captures) { + using namespace edsc::intrinsics; + assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture"); + (linalg_yield(captures)); +} void FillOp::build(OpBuilder &builder, OperationState &result, Value output, Value value) { build(builder, result, output.getType().dyn_cast(), output, value); + fillStructuredOpRegion(builder, *result.regions.front(), TypeRange{}, + TypeRange{output.getType()}, value); +} + +ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType, + OpAsmParser::OperandType valueRef) { + OpBuilder opBuilder(parser.getBuilder().getContext()); + // Resolve `valueRef` into `value` at parse time so we can build the region + // with captures. + SmallVector value; + parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value); + fillStructuredOpRegion(opBuilder, r, TypeRange{}, + TypeRange{outputType}, value); + return success(); +} + +/// FillOp region is elided when printing. 
+void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {} + +static LogicalResult verify(FillOp op) { + auto viewType = op.getOutputShapedType(0); + auto fillType = op.value().getType(); + if (viewType.getElementType() != fillType) + return op.emitOpError("expects fill type to match view elemental type"); + if (!op.getNumResults() && !viewType.isa()) { + return op.emitOpError( + "expected fill op with no result value to use memref type"); + } + return success(); +} + +void FillOp::getEffects( + SmallVectorImpl> + &effects) { + if (output().getType().isa()) + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// @@ -397,7 +540,6 @@ // InitTensorOp //===----------------------------------------------------------------------===// - static LogicalResult verify(InitTensorOp op) { RankedTensorType resultType = op.getType(); SmallVector staticSizes = llvm::to_vector<4>(llvm::map_range( @@ -1396,68 +1538,6 @@ /////// Operations corresponding to library calls defined with Tablegen //////// -void FillOp::getEffects( - SmallVectorImpl> - &effects) { - if (output().getType().isa()) - effects.emplace_back(MemoryEffects::Write::get(), output(), - SideEffects::DefaultResource::get()); -} - -static LogicalResult verify(FillOp op) { - auto viewType = op.getOutputShapedType(0); - auto fillType = op.value().getType(); - if (viewType.getElementType() != fillType) - return op.emitOpError("expects fill type to match view elemental type"); - if (!op.getNumResults() && !viewType.isa()) { - return op.emitOpError( - "expected fill op with no result value to use memref type"); - } - return success(); -} - -void CopyOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), input(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), output(), - 
SideEffects::DefaultResource::get()); -} - -static LogicalResult verify(CopyOp op) { - auto outputViewType = op.getOutputShapedType(0); - auto inputViewType = op.getInputShapedType(0); - if (inputViewType.getElementType() != outputViewType.getElementType()) - return op.emitOpError("expects views of the same type"); - if (inputViewType.getRank() != outputViewType.getRank()) - return op.emitOpError("expects views of the same rank"); - auto rank = op.getNumParallelLoops(); - auto inputPermutationMap = op.inputPermutation(); - if (inputPermutationMap) { - if (inputPermutationMap->getNumInputs() != rank) - return op.emitOpError("expects optional input_permutation map of rank ") - << rank; - if (!inputPermutationMap->isPermutation()) - return op.emitOpError( - "expects optional input_permutation map to be a permutation"); - } - auto outputPermutationMap = op.outputPermutation(); - if (outputPermutationMap) { - if (outputPermutationMap->getNumInputs() != rank) - return op.emitOpError("expects optional output_permutation map of rank ") - << rank; - if (!outputPermutationMap->isPermutation()) - return op.emitOpError( - "expects optional output_permutation map to be a permutation"); - } - if (rank == 0 && inputPermutationMap) - return op.emitOpError("expected no input permutation when rank == 0"); - if (rank == 0 && outputPermutationMap) - return op.emitOpError("expected no output permutation when rank == 0"); - return success(); -} - template static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, ArrayRef attrs, @@ -1690,14 +1770,25 @@ } //===----------------------------------------------------------------------===// -// Auto-generated Linalg named ops. +// Support for named Linalg ops defined in ods-gen. //===----------------------------------------------------------------------===// +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. 
+/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted +/// to be ShapedType. template -static void buildNamedStructuredOpRegionAndAttributesImpl( - OpBuilder &opBuilder, Region &region, TypeRange inputTypes, - TypeRange outputTypes, - std::function errorHandler) { +static void +fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, + TypeRange inputTypes, TypeRange outputTypes, + ValueRange captures, + std::function errorHandler) { + assert(llvm::all_of(inputTypes, [](Type t) { return t.isa(); })); + assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. SmallVector argTypes; @@ -1707,7 +1798,7 @@ // RAII. OpBuilder::InsertionGuard guard(opBuilder); - Block *body = opBuilder.createBlock(&region, {}, argTypes); + Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes); unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) @@ -1715,53 +1806,30 @@ opBuilder.setInsertionPointToStart(body); mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); - NamedStructuredOpType::regionBuilder(*body); + NamedStructuredOpType::regionBuilder(*body, captures); // indexing_maps is an auto-generated method. // iterator_types is an auto-generated method. } +/// Generic entry point to create both the region and the block of a LinalgOp. 
template -void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes) { +void createAndFillStructuredOpRegion(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputTypes, + ValueRange captures) { Region &region = *result.addRegion(); - buildNamedStructuredOpRegionAndAttributesImpl( - opBuilder, region, inputTypes, outputTypes, + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, captures, [&](unsigned expected, unsigned actual) { - llvm::errs() << "region expects " << expected << " args, got " - << actual; assert(expected != actual && "incorrect number of arguments"); }); } -template -static ParseResult -parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, - TypeRange inputTypes, TypeRange outputTypes) { - ParseResult res = success(); - OpBuilder opBuilder(parser.getBuilder().getContext()); - buildNamedStructuredOpRegionAndAttributesImpl( - opBuilder, region, inputTypes, outputTypes, - [&](unsigned expected, unsigned actual) { - res = parser.emitError(parser.getCurrentLocation(), - llvm::formatv("region expects {0} args, got {1}", - expected, actual)); - }); - return res; -} - -static ParseResult -parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes) { - if (succeeded(parser.parseOptionalArrow())) - if (parser.parseTypeList(resultTypes)) - return failure(); - return success(); -} - +/// Common parsing used for both named structured ops created by ods-gen and by +/// manually defined C++ ops. Does not handle regions. 
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, @@ -1802,8 +1870,56 @@ } template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result) { +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op) { + if (!op.inputs().empty()) + p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; + if (!op.outputs().empty()) + p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; +} + +//===----------------------------------------------------------------------===// +// Specific parsing and printing for named structured ops created by ods-gen. +//===----------------------------------------------------------------------===// + +template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef captures) { + ParseResult res = success(); + OpBuilder opBuilder(parser.getBuilder().getContext()); + // Resolve `captures` into `capturedValues` at parse time so we can build the + // region with captures. + SmallVector capturedValues; + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, capturedValues, + [&](unsigned expected, unsigned actual) { + res = parser.emitError( + parser.getCurrentLocation(), + llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " + "region expects {0} args, got {1}", + expected, actual)); + region.front().dump(); + }); + return res; +} + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes) { + if (succeeded(parser.parseOptionalArrow())) + if (parser.parseTypeList(resultTypes)) + return failure(); + return success(); +} + +template +static ParseResult +parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, + ArrayRef captures) { + // TODO: Enable when ods-gen supports captures. 
+ assert(captures.empty() && "unexpected captures for named structured ops"); SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); @@ -1817,7 +1933,7 @@ std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes)) + parser, *region, inputTypes, outputTypes, captures)) return failure(); result.addRegion(std::move(region)); @@ -1831,15 +1947,6 @@ p.printOptionalArrowTypeList(resultTypes); } -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op) { - if (!op.inputs().empty()) - p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; - if (!op.outputs().empty()) - p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; -} - template static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { p << op.getOperationName(); @@ -1861,6 +1968,10 @@ return verifyGenericOp(op); } +//===----------------------------------------------------------------------===// +// Canonicalizers and Folders. 
+//===----------------------------------------------------------------------===// + namespace { struct EraseDeadLinalgOp : public RewritePattern { EraseDeadLinalgOp(PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -49,7 +49,7 @@ indexingMaps, iterators, [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { edsc::ScopedContext scope(bodyBuilder, loc); - regionBuilder(*bodyBuilder.getBlock()); + regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{}); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -52,14 +52,6 @@ return res; } -static SmallVector permuteIvs(ArrayRef ivs, - Optional permutation) { - return permutation ? 
applyMapToValues(ScopedContext::getBuilderRef(), - ScopedContext::getLocation(), - permutation.getValue(), ivs) - : SmallVector(ivs.begin(), ivs.end()); -} - template static void inlineRegionAndEmitStore(OpType op, ArrayRef indexedValues, ArrayRef> indexing, @@ -178,40 +170,6 @@ outputBuffers); } -template -static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { - assert(copyOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto nPar = copyOp.getNumParallelLoops(); - assert(nPar == allIvs.size()); - auto inputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); - auto outputIvs = - permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); - SmallVector iivs(inputIvs.begin(), inputIvs.end()); - SmallVector oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); - // Emit the proper scalar assignment, whether we are dealing with a 0-D or - // an n-D loop nest; with or without permutations. - // clang-format off - nPar > 0 ? O(oivs) = I(iivs) : - O() = I(); - // clang-format on -} - -template -static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { - assert(fillOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - auto nPar = fillOp.getNumParallelLoops(); - assert(nPar == allIvs.size()); - auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutputBuffer(0)); - // Emit the proper scalar assignment, whether we are dealing with a 0-D or - // an n-D loop nest; with or without permutations. - nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); -} - // Create a padded view into the given `input` tensor using the 'indices' // to access the tensor. `skipPadding` lists the dimensions for which no padding // is needed e.g. the non-spatial dimensions for convolutions. 
@@ -533,8 +491,8 @@ assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); llvm::TypeSwitch(op) - .Case([&](auto op) { + .Case([&](auto op) { emitScalarImplementation(allIvs, op); }) .Default([&](Operation *op) { assert(false && "unexpected op"); }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -267,7 +267,7 @@ llvm::map_range(linalgOp.getShapedOperandTypes(), [](ShapedType t) { return t.getElementType(); })); block->addArguments(elementTypes); - linalgOp.getRegionBuilder()(*block); + linalgOp.getRegionBuilder()(*block, /*captures=*/{}); } Block *block = &region->front(); @@ -333,24 +333,26 @@ // Return true if the op is an element-wise linalg op. static bool isElementwise(Operation *op) { - auto genericOp = dyn_cast(op); - if (!genericOp) + auto linalgOp = dyn_cast(op); + if (!linalgOp) return false; - if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) + if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) return false; // TODO: relax the restrictions on indexing map. - for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) { - if (!genericOp.getOutputIndexingMap(i).isIdentity()) + for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) { + if (!linalgOp.getOutputIndexingMap(i).isIdentity()) return false; } // Currently bound the input indexing map to minor identity as other // permutations might require adding transpose ops to convert the vector read // to the right shape. 
- for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) { - if (!genericOp.getInputIndexingMap(i).isMinorIdentity()) + for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) { + if (!linalgOp.getInputIndexingMap(i).isMinorIdentity()) return false; } - return hasOnlyScalarElementwiseOp(genericOp.getRegion()); + if (linalgOp->getNumRegions() != 1) + return false; + return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0)); } static Optional vectorizeContraction(OpBuilder &builder, @@ -393,9 +395,6 @@ for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast().hasStaticShape()) return failure(); - - if (isa(op)) - return success(); if (isElementwise(op)) return success(); return success(isaContractionOpInterface(linalgOp)); @@ -407,43 +406,12 @@ return llvm::None; edsc::ScopedContext scope(builder, op->getLoc()); - // In the case of 0-D memrefs, return null and special case to scalar load or - // store later. - if (auto fillOp = dyn_cast(op)) { - // Vectorize fill as a vector.broadcast. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Rewrite linalg.fill as vector.broadcast: " << *op); - VectorizedLinalgOp res; - if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output())) - res.tensorResults.push_back(v); - return res; - } - if (auto copyOp = dyn_cast(op)) { - // Vectorize copy as a vector.transfer_read+vector.transfer_write. - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Rewrite linalg.copy as vector.transfer_read + " - "vector.transfer_write: " - << *op); - Value vector = buildVectorRead(builder, copyOp.input()); - VectorizedLinalgOp res; - if (Value v = buildVectorWrite(builder, vector, copyOp.output())) - res.tensorResults.push_back(v); - return res; - } if (isElementwise(op)) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " << "Vectorize linalg op as a generic: " << *op); return vectorizeAsLinalgGeneric(builder, cast(op)); } - // TODO: as soon as Copy and FillOp. 
get a region builder, replace all the - // above by: - // if (isa(op) || isElementwise(op)) { - // LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - // << "Vectorize linalg op as a generic: " << *op); - // return vectorizeAsLinalgGeneric(builder, cast(op)); - // } - return vectorizeContraction(builder, cast(op)); } diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir --- a/mlir/test/Transforms/copy-removal.mlir +++ b/mlir/test/Transforms/copy-removal.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt -copy-removal -split-input-file %s -//| FileCheck %s +// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s // All linalg copies except the linalg.copy(%1, %9) must be removed since the // defining operation of %1 and its DeallocOp have been defined in another block. @@ -256,7 +255,7 @@ %tmp2 = math.exp %gen2_arg0 : f32 linalg.yield %tmp2 : f32 } - "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> () + linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32> dealloc %temp : memref<2xf32> // CHECK: return return @@ -292,7 +291,7 @@ linalg.yield %tmp2 : f32 } // CHECK: linalg.copy - "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> () + linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32> dealloc %temp : memref<2xf32> return } @@ -355,7 +354,7 @@ } // CHECK-NOT: linalg.copy // CHECK-NOT: dealloc - "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> () + linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32> dealloc %0 : memref<4xf32> //CHECK: return return diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -23,7 +23,7 @@ // IMPL-NEXT: map2 = simplifyAffineMap(map2); // IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // -// IMPL: void Test1Op::regionBuilder(Block &block) { +// IMPL: void 
Test1Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -47,7 +47,7 @@ // IMPL: AffineMap::get(3, 3, {d2, d1}, context) // IMPL: AffineMap::get(3, 3, {d0, d1}, context) // -// IMPL: Test2Op::regionBuilder(Block &block) { +// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); @@ -71,7 +71,7 @@ // IMPL: AffineMap::get(4, 4, {d3, d2}, context) // IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // -// IMPL: Test3Op::regionBuilder(Block &block) { +// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1871,11 +1871,11 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - buildNamedStructuredOpRegionAndAttributes<{0}>( + createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs)/*, TODO: support captures*/); }]>, OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -1889,11 +1889,11 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - buildNamedStructuredOpRegionAndAttributes<{0}>( + createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs)/*, TODO: 
support captures*/); }]>, OpBuilderDAG< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -1907,7 +1907,9 @@ {6} ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; - let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }]; + let parser = [{{ + return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); + }]; let hasFolder = 1; let hasCanonicalizer = 1; @@ -1915,8 +1917,8 @@ // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(Block &block); - static std::function getRegionBuilder() {{ + static void regionBuilder(Block &block, ValueRange captures); + static std::function getRegionBuilder() {{ return regionBuilder; } @@ -1980,11 +1982,11 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - buildNamedStructuredOpRegionAndAttributes<{0}>( + createAndFillStructuredOpRegion<{0}>( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs)/*, TODO: support captures*/); {2} }]> )FMT"; @@ -2311,7 +2313,7 @@ }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(Block &block) { + void {0}::regionBuilder(Block &block, ValueRange captures) { using namespace edsc; using namespace intrinsics; auto args = block.getArguments();