diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -0,0 +1,3 @@ +def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { + C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); +} diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -806,11 +806,22 @@ def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">; class LinalgNamedStructured_Op props> - : Op { + : LinalgStructuredBase_Op { string spec = ?; - let assemblyFormat = "`(` operands `)` attr-dict `:` " - "functional-type(operands, results)"; + // We cannot use an assemblyFormat atm because we need to hook in a custom- + // built implicit region from a static OpClass method. + // TODO(ntv): Revisit in the future if/when appropriate. + // let assemblyFormat = "`(` operands `)` attr-dict `:` " + // "functional-type(operands, results)"; + + // The parser needs to specialize on the OpType so it has to be auto-generated + // in the linalg-ods tool. + let printer = [{ return ::printNamedStructuredOp(p, *this); }]; + let verifier = [{ return ::verifyNamedStructuredOp(*this); }]; + let hasFolder = 1; } +// This file is auto-generated from a tc specification. 
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td" + #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -64,7 +64,8 @@ "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ - Return the type of the input shape at the given index. + Return the `i`-th input shaped type, irrespective of buffer or tensor + type. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ Return the subset of input operands that are of ranked tensor type. @@ -89,6 +90,10 @@ InterfaceMethod<[{ Return the type of the output buffer at the given index. }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Return the `i`-th output shaped type, irrespective of buffer or tensor + type. + }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ Return the results that are of ranked tensor type. }], "SmallVector", "getOutputTensorTypes">, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" @@ -119,7 +120,8 @@ return it - getInputs().begin(); return llvm::None; } - /// Return the `i`-th input buffer type. + /// Return the `i`-th input shaped type, irrespective of buffer or tensor + /// type. 
ShapedType getInputShapedType(unsigned i) { return getInput(i).getType().template cast(); } @@ -344,6 +346,17 @@ } }; +/// This class provides the API for named Linalg StructuredOps. +template +class NamedStructuredOpTraits + : public OpTrait::TraitBase { +public: + llvm::Optional> referenceIterators(); + llvm::Optional> referenceIndexingMaps(); + std::function)> + emitScalarImplementation(); +}; + } // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" @@ -30,6 +31,20 @@ using namespace mlir; using namespace mlir::linalg; +/// Forward declarations. 
+template +static void buildNamedStructuredOpRegion(Builder &builder, + OperationState &result, + TypeRange operandTypes, + TypeRange tensorResultTypes); +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); +template +static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); + /// Determines whether it is possible to fold it away in the parent Linalg op: /// /// ```mlir @@ -203,7 +218,9 @@ return success(); } -LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { +template <> +LogicalResult BlockArgsVerifier::verify(IndexedGenericOp op, + Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); auto nOperands = op.getNumOperands(); @@ -245,7 +262,8 @@ auto ®ion = op.region(); if (region.getBlocks().size() != 1) return op.emitOpError("expected region with 1 block"); - if (failed(verifyBlockArgs(op, region.getBlocks().front()))) + if (failed(BlockArgsVerifier::verify( + op, region.getBlocks().front()))) return failure(); SmallVector indexingMaps; @@ -737,17 +755,17 @@ parser.resolveOperands(opInfo, types, loc, result.operands)); } -template -static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { +static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) { // The operand number and types must match the view element types. 
- auto nOutputs = genericOp.getNumOutputs(); + auto nOutputs = linalgOpInterface.getNumOutputs(); if (op.getNumOperands() != nOutputs) return op.emitOpError("expected number of yield values (") << nOutputs << ") to match the number of operands of the enclosing " - << "linalg.generic op (" << op.getNumOperands() << ")"; + << "LinalgOp (" << op.getNumOperands() << ")"; for (unsigned i = 0; i != nOutputs; ++i) { - auto elementType = genericOp.getOutputShapedType(i).getElementType(); + auto elementType = + linalgOpInterface.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) return op.emitOpError("type of yield operand ") << (i + 1) << " (" << op.getOperand(i).getType() @@ -763,17 +781,10 @@ if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return op.emitOpError("expected single non-empty parent region"); - auto genericOp = dyn_cast(parentOp); - if (genericOp) - return verifyYield(op, genericOp); - - auto indexedGenericOp = dyn_cast(parentOp); - if (indexedGenericOp) - return verifyYield(op, indexedGenericOp); + if (auto linalgOp = dyn_cast(parentOp)) + return verifyYield(op, cast(parentOp)); - return op.emitOpError("expected '") - << GenericOp::getOperationName() << "' or '" - << IndexedGenericOp::getOperationName() << "' parent op"; + return op.emitOpError("expected parent op with LinalgOp interface"); } /////// Operations corresponding to library calls defined with Tablegen //////// @@ -1056,3 +1067,92 @@ return getResult(); return {}; } + +//===----------------------------------------------------------------------===// +// Auto-generated Linalg named ops. 
+//===----------------------------------------------------------------------===// + +template +void buildNamedStructuredOpRegion(Builder &builder, OperationState &result, + TypeRange operandTypes, + TypeRange tensorResultTypes) { + auto *op = Operation::create(builder.getUnknownLoc(), + OperationName("fake_op", builder.getContext()), + ArrayRef{}, ArrayRef{}, + ArrayRef{}, ArrayRef{}, + /*numRegions=*/1, + /*resizableOperandList=*/false); + std::unique_ptr> guard( + (int *)1, [&op](int *) { op->destroy(); }); + + Region &bodyRegion = op->getRegion(0); + Block *body = new Block(); + // TODO(ntv): atm all operands go through getElementTypeOrSelf, + // reconsider when we have evidence we need to. + for (auto t : operandTypes) + body->addArgument(getElementTypeOrSelf(t)); + for (auto t : tensorResultTypes) + body->addArgument(getElementTypeOrSelf(t)); + bodyRegion.push_back(body); + OpBuilder opBuilder(bodyRegion); + mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc()); + NamedStructuredOpType::regionBuilder(*body); + + // Steal the region and let op be destroyed. + result.addRegion()->takeBody(bodyRegion); +} + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { + p << op.getOperationName() << " "; + p.printOptionalAttrDict(op.getAttrs()); + p << "(" << op.getOperands() << ")"; + p << ": (" << op.getOperandTypes() << ")"; + auto outputTensorTypes = op.getResultTypes(); + if (!outputTensorTypes.empty()) + p << " -> (" << outputTensorTypes << ")"; +} + +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { + SmallVector operandsInfo; + + // Optional attributes may be added. 
+ if (parser.parseOptionalAttrDict(result.attributes) || parser.parseLParen() || + parser.parseOperandList(operandsInfo) || parser.parseRParen()) + return failure(); + + SmallVector operandTypes; + if (parser.parseColon() || parser.parseLParen() || + parser.parseTypeList(operandTypes) || parser.parseRParen()) + return failure(); + + // Generic ops may specify that a subset of its outputs are tensors. Such + // outputs are specified in the result type. + SmallVector tensorResultTypes; + if (parser.parseOptionalArrowTypeList(tensorResultTypes)) + return failure(); + + if (!tensorResultTypes.empty()) + result.addTypes(tensorResultTypes); + + buildNamedStructuredOpRegion( + parser.getBuilder(), result, operandTypes, tensorResultTypes); + + return parser.resolveOperands(operandsInfo, operandTypes, + parser.getCurrentLocation(), result.operands); +} + +template +static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { + return verifyGenericOp(op); +} + +#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" + +// TODO(ntv): Determine whether we can generate the folders and verifiers. 
+LogicalResult batchmatmulOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -121,8 +121,19 @@ } namespace { + template -class LinalgScopedEmitter {}; +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + LinalgOpType linalgOp) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + llvm_unreachable("NYI"); + linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(), + ScopedContext::getLocation(), allIvs); + } +}; template class LinalgScopedEmitter { diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -48,7 +48,7 @@ // ----- func @yield_parent(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}} + // expected-error @+1 {{op expected parent op with LinalgOp interface}} linalg.yield %arg0: memref(off + i)>> } @@ -91,7 +91,7 @@ // ----- func @generic_mismatched_num_returns(%arg0: memref) { - // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (0)}} + // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}} linalg.generic { args_in = 0, args_out = 1, @@ -304,7 +304,7 @@ // ----- func @indexed_generic_result_count(%arg0: memref) { - // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (2)}} + // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing 
LinalgOp (2)}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -420,3 +420,11 @@ memref, memref<2x3xf32>, memref return } + +// ----- + +func @named_ops(%a3: memref, %b3: memref, %c3: memref) { + // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref'}} + linalg.batchmatmul(%a3, %b3, %c3): (memref, memref, memref) -> () + return +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -620,3 +620,16 @@ // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref + + +// TODO(ntv): Return tensors need a semantics convention update. +func @named_ops(%a3: memref, %b3: memref, %c3: memref, + %ta3: tensor, %tb3: tensor, %tc3: tensor) { + linalg.batchmatmul(%a3, %b3, %c3): (memref, memref, memref) -> () + linalg.batchmatmul(%ta3, %tb3, %c3): (tensor, tensor, memref) -> () + return +} +// CHECK-LABEL: func @named_ops +// CHECK: linalg.batchmatmul +// CHECK: linalg.batchmatmul + diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -4,72 +4,72 @@ // RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 -test-emit-include-td-header \ // RUN: | mlir-tblgen -gen-op-decls -I %S/../../include -// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [ +// ODS-LABEL: def test1Op : LinalgNamedStructured_Op<"test1", [ // ODS-NEXT: NInputs<2>, // ODS-NEXT: NOutputs<1>, // ODS-NEXT: NamedStructuredOpTraits]> // -// IMPL-LABEL: matvec::referenceIterators() { +// IMPL-LABEL: test1Op::referenceIterators() { // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: matvec::referenceIndexingMaps() { +// IMPL: test1Op::referenceIndexingMaps() { 
// IMPL: AffineMap::get(2, 0, {d0, d1}), // IMPL-NEXT: AffineMap::get(2, 0, {d1}), // IMPL-NEXT: AffineMap::get(2, 0, {d0}) }; // -// IMPL: matvec::regionBuilder(ArrayRef args) { +// IMPL: test1Op::regionBuilder(Block &block) { // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // -def matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { +def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { C(m) = std_addf(std_mulf(A(m, k), B(k))); } -// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [ +// ODS-LABEL: def test2Op : LinalgNamedStructured_Op<"test2", [ // ODS-NEXT: NInputs<2>, // ODS-NEXT: NOutputs<1>, // ODS-NEXT: NamedStructuredOpTraits]> // -// IMPL-LABEL: matmul::referenceIterators() { +// IMPL-LABEL: test2Op::referenceIterators() { // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: matmul::referenceIndexingMaps() { +// IMPL: test2Op::referenceIndexingMaps() { // IMPL: AffineMap::get(3, 0, {d0, d2}), // IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}), // IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}) }; // -// IMPL: matmul::regionBuilder(ArrayRef args) { +// IMPL: test2Op::regionBuilder(Block &block) { // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // -def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { +def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } -// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [ +// ODS-LABEL: def test3Op : LinalgNamedStructured_Op<"test3", [ // ODS-NEXT: NInputs<2>, // ODS-NEXT: NOutputs<1>, // ODS-NEXT: NamedStructuredOpTraits]> // -// 
IMPL-LABEL: batchmatmul::referenceIterators() { +// IMPL-LABEL: test3Op::referenceIterators() { // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: batchmatmul::referenceIndexingMaps() { +// IMPL: test3Op::referenceIndexingMaps() { // IMPL: AffineMap::get(4, 0, {d0, d1, d3}), // IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}), // IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}) }; // -// IMPL: batchmatmul::regionBuilder(ArrayRef args) { +// IMPL: test3Op::regionBuilder(Block &block) { // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // -// TBLGEN: batchmatmulOp -def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { +// TBLGEN: test3Op +def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -896,7 +896,8 @@ TensorExpr(StringRef name, SmallVectorImpl> &&exprs, ArrayRef reductionDims) - : Expression(Kind::TensorExpr), opId(name), expressions(std::move(exprs)), + : Expression(Kind::TensorExpr), operationName(name), + expressions(std::move(exprs)), reductionDimensions(reductionDims.begin(), reductionDims.end()) {} static bool classof(const Expression *e) { @@ -904,7 +905,7 @@ } bool operator==(const TensorExpr &other) const { - if (opId != other.opId) + if (operationName != other.operationName) return false; if (expressions.size() != other.expressions.size()) return false; @@ -922,7 +923,7 @@ template void visit(Lambda callback) const; - StringRef opId; + StringRef operationName; SmallVector, 4> expressions; 
SetVector reductionDimensions; }; @@ -995,15 +996,15 @@ StringRef linalgOpName); /// Print the C++ StructuredOpsInterface impl of `referenceIterators`. - void printReferenceIterators(llvm::raw_ostream &os, StringRef opId, + void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. - void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId, + void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. - void printRegionBuilder(llvm::raw_ostream &os, StringRef opId, + void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); private: @@ -1364,6 +1365,7 @@ return failure(); StringRef tcName = parser.curToken.getSpelling(); + std::string cppOpName = (tcName + "Op").str(); LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing tc: " << tcName << "\n"); if (failed(parser.parseToken(Token::Kind::id, "expected id")) || failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) @@ -1404,7 +1406,7 @@ SmallVector perComprehensionStates; while (parser.curToken.isNot(Token::Kind::r_brace)) { perComprehensionStates.push_back(ComprehensionParsingState()); - if (failed(parseOneComprehension(tcName, tcName, + if (failed(parseOneComprehension(cppOpName, tcName, perComprehensionStates.back()))) return failure(); }; @@ -1418,16 +1420,16 @@ return failure(); } if (genODSDecl) { - printODS(os, tcName, tcName); + printODS(os, cppOpName, tcName); os << "\n"; } if (genODSImpl) { auto &state = perComprehensionStates.back(); std::string extraMethods; llvm::raw_string_ostream ss(extraMethods); - printReferenceIterators(ss, tcName, state); - printReferenceIndexingMaps(ss, tcName, state); - printRegionBuilder(ss, tcName, state); + printReferenceIterators(ss, cppOpName, state); + 
printReferenceIndexingMaps(ss, cppOpName, state); + printRegionBuilder(ss, cppOpName, state); ss.flush(); os << extraMethods << "\n"; } @@ -1442,18 +1444,31 @@ /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, StringRef linalgOpName) { - const char *header = R"FMT( def {0}Op : LinalgNamedStructured_Op<"{1}", [ + const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [ NInputs<{2}>, NOutputs<{3}>, NamedStructuredOpTraits]> { let arguments = (ins Variadic:$views); let results = (outs Variadic:$output_tensors); + let regions = (region SizedRegion<1>:$region); + let builders = [OpBuilder< + "Builder *b, OperationState &result, TypeRange outputTypes, " + # "ValueRange views", + [{{ + result.addOperands(views); + result.addTypes(outputTypes); + buildNamedStructuredOpRegion<{0}>( + *b, result, TypeRange(views), outputTypes); + }]> + ]; + let parser = [{ + return ::parseNamedStructuredOp<{0}>(parser, result); + }]; let extraClassDeclaration = [{{ llvm::Optional> referenceIterators(); llvm::Optional> referenceIndexingMaps(); - void regionBuilder(ArrayRef args); + static void regionBuilder(Block &block); }]; - let hasFolder = 1; })FMT"; unsigned nInputs = 0, nOutputs = 0; @@ -1468,7 +1483,8 @@ } /// Print the C++ StructuredOpsInterface impl of `referenceIterators`. -void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId, +void TCParser::printReferenceIterators(llvm::raw_ostream &os, + StringRef cppOpName, ComprehensionParsingState &state) { const char *referenceReferenceIteratorsFmt = R"FMT( @@ -1498,11 +1514,12 @@ }); ss.flush(); - os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr); + os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr); } /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. 
-void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId, +void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, + StringRef cppOpName, ComprehensionParsingState &state) { const char *referenceIndexingMapsFmt = R"FMT( @@ -1544,11 +1561,11 @@ }); mapsStringStream.flush(); - os << llvm::formatv(referenceIndexingMapsFmt, opId, dimsStr, mapsStr); + os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr); } /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. -void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId, +void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state) { unsigned count = state.orderedTensorArgs.size(); llvm::DenseMap subExprsMap; @@ -1570,15 +1587,17 @@ }); subExprsStringStream.flush(); const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});"; - os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs); + os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName, + subExprs); subExprsMap[pTensorExpr] = count; } }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(ArrayRef args) { + void {0}::regionBuilder(Block &block) { using namespace edsc; using namespace intrinsics; + auto args = block.getArguments(); ValueHandle {1}; {2} (linalg_yield(ValueRange{ {3} })); @@ -1612,8 +1631,8 @@ expressionStringStream.flush(); yieldStringStream.flush(); - os << llvm::formatv(regionBuilderFmt, opId, valueHandleStr, expressionsStr, - yieldStr); + os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr, + expressionsStr, yieldStr); } /// Iterate over each Tensor Comprehension def.