diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -1,7 +1,38 @@ +# Declare a function to generate ODS with mlir-linalg-ods-gen +function(add_linalg_ods_gen tc_filename output_file) + set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename}) + set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.td) + set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.cpp.inc) + set_source_files_properties( + ${GEN_ODS_FILE} + PROPERTIES GENERATED TRUE) + set_source_files_properties( + ${GEN_CPP_FILE} + PROPERTIES GENERATED TRUE) + add_custom_command( + OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE} + COMMAND mlir-linalg-ods-gen -gen-ods-decl ${TC_SOURCE} > ${GEN_ODS_FILE} + COMMAND mlir-linalg-ods-gen -gen-impl ${TC_SOURCE} > ${GEN_CPP_FILE} + MAIN_DEPENDENCY + ${TC_SOURCE} + DEPENDS + mlir-linalg-ods-gen + VERBATIM) + add_custom_target( + MLIR${output_file}IncGen + DEPENDS + mlir-linalg-ods-gen + ${GEN_ODS_FILE} ${GEN_CPP_FILE}) +endfunction() + add_mlir_dialect(LinalgOps linalg) add_mlir_doc(LinalgDoc -gen-op-doc LinalgOps Dialects/) +add_linalg_ods_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps) set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td) +set(TABLEGEN_OUTPUT + ${TABLEGEN_OUTPUT} + MLIRLinalgNamedStructuredOpsIncGen) mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls) mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -0,0 +1,4 @@ +ods_def: +def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { + C(b, m, n) = 
std_addf(std_mulf(A(b, m, k), B(b, k, n))); +} diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -516,7 +516,8 @@ CPred<"$_self.cast().getRank() == " # rank>] >>; -class GenericOpBase : LinalgStructuredBase_Op { +class GenericOpBase : LinalgStructuredBase_Op]> { let arguments = (ins Variadic:$views, I64Attr:$args_in, I64Attr:$args_out, @@ -806,11 +807,22 @@ def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">; class LinalgNamedStructured_Op props> - : Op { + : LinalgStructuredBase_Op { string spec = ?; - let assemblyFormat = "`(` operands `)` attr-dict `:` " - "functional-type(operands, results)"; + // We cannot use an assemblyFormat atm because we need to hook in a custom- + // built implicit region from a static OpClass method. + // TODO: Revisit in the future if/when appropriate. + // let assemblyFormat = "`(` operands `)` attr-dict `:` " + // "functional-type(operands, results)"; + + // The parser needs to specialize on the OpType so it has to be auto-generated + // in the linalg-ods tool. + let printer = [{ return ::printNamedStructuredOp(p, *this); }]; + let verifier = [{ return ::verifyNamedStructuredOp(*this); }]; + let hasFolder = 1; } +// This file is auto-generated from a tc specification. +include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td" + #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -64,7 +64,8 @@ "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ - Return the type of the input shape at the given index. 
+ Return the `i`-th input shaped type, irrespective of buffer or tensor + type. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ Return the subset of input operands that are of ranked tensor type. @@ -89,6 +90,10 @@ InterfaceMethod<[{ Return the type of the output buffer at the given index. }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Return the `i`-th output shaped type, irrespective of buffer or tensor + type. + }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ Return the results that are of ranked tensor type. }], "SmallVector", "getOutputTensorTypes">, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" @@ -119,7 +120,8 @@ return it - getInputs().begin(); return llvm::None; } - /// Return the `i`-th input buffer type. + /// Return the `i`-th input shaped type, irrespective of buffer or tensor + /// type. ShapedType getInputShapedType(unsigned i) { return getInput(i).getType().template cast(); } @@ -344,6 +346,17 @@ } }; +/// This class provides the API for named Linalg StructuredOps. 
+template +class NamedStructuredOpTraits + : public OpTrait::TraitBase { +public: + llvm::Optional> referenceIterators(); + llvm::Optional> referenceIndexingMaps(); + std::function)> + emitScalarImplementation(); +}; + } // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" @@ -30,6 +31,20 @@ using namespace mlir; using namespace mlir::linalg; +/// Forward declarations. +template +static void buildNamedStructuredOpRegion(Builder &builder, + OperationState &result, + TypeRange operandTypes, + TypeRange tensorResultTypes); +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); +template +static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); + /// Determines whether it is possible to fold it away in the parent Linalg op: /// /// ```mlir @@ -184,7 +199,14 @@ parser.getCurrentLocation(), result.operands); } -LogicalResult verifyBlockArgs(GenericOp op, Block &block) { +template +struct BlockArgsVerifier { + static LogicalResult verify(GenericOpType op, Block &block); +}; + +template +LogicalResult BlockArgsVerifier::verify(GenericOpType op, + Block &block) { auto nOperands = op.getNumOperands(); if (block.getNumArguments() != nOperands) return op.emitOpError("expected number of block arguments to match number " @@ -203,7 +225,9 @@ return success(); } -LogicalResult verifyBlockArgs(IndexedGenericOp op, Block 
&block) { +template <> +LogicalResult BlockArgsVerifier::verify(IndexedGenericOp op, + Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); auto nOperands = op.getNumOperands(); @@ -245,7 +269,7 @@ auto ®ion = op.region(); if (region.getBlocks().size() != 1) return op.emitOpError("expected region with 1 block"); - if (failed(verifyBlockArgs(op, region.getBlocks().front()))) + if (failed(BlockArgsVerifier::verify(op, region.front()))) return failure(); SmallVector indexingMaps; @@ -737,17 +761,18 @@ parser.resolveOperands(opInfo, types, loc, result.operands)); } -template -static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { - // The operand number and types must match the view element types. - auto nOutputs = genericOp.getNumOutputs(); +// Check the operand number and types must match the element types of the +// LinalgOp interface's shaped operands. +static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) { + auto nOutputs = linalgOpInterface.getNumOutputs(); if (op.getNumOperands() != nOutputs) return op.emitOpError("expected number of yield values (") << nOutputs << ") to match the number of operands of the enclosing " - << "linalg.generic op (" << op.getNumOperands() << ")"; + << "LinalgOp (" << op.getNumOperands() << ")"; for (unsigned i = 0; i != nOutputs; ++i) { - auto elementType = genericOp.getOutputShapedType(i).getElementType(); + auto elementType = + linalgOpInterface.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) return op.emitOpError("type of yield operand ") << (i + 1) << " (" << op.getOperand(i).getType() @@ -763,17 +788,10 @@ if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return op.emitOpError("expected single non-empty parent region"); - auto genericOp = dyn_cast(parentOp); - if (genericOp) - return verifyYield(op, genericOp); - - auto indexedGenericOp = dyn_cast(parentOp); - if (indexedGenericOp) - return 
verifyYield(op, indexedGenericOp); + if (auto linalgOp = dyn_cast<LinalgOp>(parentOp)) + return verifyYield(op, linalgOp); - return op.emitOpError("expected '") - << GenericOp::getOperationName() << "' or '" - << IndexedGenericOp::getOperationName() << "' parent op"; + return op.emitOpError("expected parent op with LinalgOp interface"); } /////// Operations corresponding to library calls defined with Tablegen //////// @@ -1056,3 +1074,82 @@ return getResult(); return {}; } + +//===----------------------------------------------------------------------===// +// Auto-generated Linalg named ops. +//===----------------------------------------------------------------------===// + +template <typename NamedStructuredOpType> +void buildNamedStructuredOpRegion(Builder &builder, OperationState &result, + TypeRange operandTypes, + TypeRange tensorResultTypes) { + Region &region = *result.addRegion(); + Block *body = new Block(); + // TODO: atm all operands go through getElementTypeOrSelf, + // reconsider when we have evidence we need to. + for (auto t : operandTypes) + body->addArgument(getElementTypeOrSelf(t)); + for (auto t : tensorResultTypes) + body->addArgument(getElementTypeOrSelf(t)); + region.push_back(body); + + OpBuilder opBuilder(builder.getContext()); + opBuilder.setInsertionPointToStart(&region.front()); + mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc()); + NamedStructuredOpType::regionBuilder(*body); +} + +template <typename NamedStructuredOpType> +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { + p << op.getOperationName() << ' '; + p.printOptionalAttrDict(op.getAttrs()); + p << ' ' << op.getOperands(); + p << ": (" << op.getOperandTypes() << ")"; + auto outputTensorTypes = op.getResultTypes(); + if (!outputTensorTypes.empty()) + p << " -> (" << outputTensorTypes << ")"; +} + +template <typename NamedStructuredOpType> +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { + SmallVector operandsInfo; + + // Optional attributes may be added.
+ if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseOperandList(operandsInfo)) + return failure(); + + SmallVector operandTypes; + if (parser.parseColon() || parser.parseLParen() || + parser.parseTypeList(operandTypes) || parser.parseRParen()) + return failure(); + + // Generic ops may specify that a subset of its outputs are tensors. Such + // outputs are specified in the result type. + SmallVector tensorResultTypes; + if (parser.parseOptionalArrowTypeList(tensorResultTypes)) + return failure(); + + if (!tensorResultTypes.empty()) + result.addTypes(tensorResultTypes); + + buildNamedStructuredOpRegion( + parser.getBuilder(), result, operandTypes, tensorResultTypes); + + return parser.resolveOperands(operandsInfo, operandTypes, + parser.getCurrentLocation(), result.operands); +} + +template +static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { + return verifyGenericOp(op); +} + +#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" + +// TODO: Determine whether we can generate the folders and verifiers. +LogicalResult BatchMatmulOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -121,8 +121,21 @@ } namespace { + +// Generic loop emitter, to be specialized on an op-per op basis. 
+// TODO: template -class LinalgScopedEmitter {}; +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + LinalgOpType linalgOp) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + llvm_unreachable("NYI"); + linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(), + ScopedContext::getLocation(), allIvs); + } +}; template class LinalgScopedEmitter { diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -48,7 +48,7 @@ // ----- func @yield_parent(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}} + // expected-error @+1 {{op expected parent op with LinalgOp interface}} linalg.yield %arg0: memref(off + i)>> } @@ -91,7 +91,7 @@ // ----- func @generic_mismatched_num_returns(%arg0: memref) { - // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (0)}} + // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}} linalg.generic { args_in = 0, args_out = 1, @@ -114,6 +114,7 @@ iterator_types = ["parallel"] } %arg0 { ^bb(%i : i32): + linalg.yield %i : i32 }: memref } @@ -128,6 +129,7 @@ iterator_types = ["parallel"] } %arg0 { ^bb(%i : i32): + linalg.yield %i : i32 }: memref<1xi32> } @@ -188,7 +190,8 @@ // ----- func @generic_empty_region(%arg0: memref) { - // expected-error @+1 {{op expected region with 1 block}} + %f0 = constant 0.0: f32 + // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}} linalg.generic { args_in = 1, args_out = 1, @@ -196,7 +199,23 @@ iterator_types = [] } %arg0, %arg0 { ^bb1: + linalg.yield %f0: f32 ^bb2: + linalg.yield %f0: f32 + }: memref, memref +} + +// ----- + +func @generic_empty_region(%arg0: memref) { + 
%f0 = constant 0.0: f32 + // expected-error @+1 {{linalg.generic' op expected region with 1 block}} + linalg.generic { + args_in = 1, + args_out = 1, + indexing_maps = [ affine_map<() -> (0)> ], + iterator_types = [] + } %arg0, %arg0 { }: memref, memref } @@ -210,7 +229,8 @@ indexing_maps = [ affine_map<() -> (0)> ], iterator_types = [] } %arg0 { - ^bb: + ^bb(%f: f32, %g: f32): + linalg.yield %f: f32 }: memref } @@ -225,6 +245,7 @@ iterator_types = [] } %arg0 { ^bb(%i: i1): + linalg.yield %i : i1 }: memref } @@ -239,6 +260,7 @@ iterator_types = ["parallel"] } %arg0 { ^bb(%f: f32): + linalg.yield %f : f32 }: memref } @@ -253,6 +275,7 @@ iterator_types = ["parallel"] } %arg0 { ^bb(%i: f64, %f: f32): + linalg.yield %f: f32 }: memref } @@ -267,6 +290,7 @@ iterator_types = ["parallel"] } %arg0 { ^bb(%i: index, %f: i1): + linalg.yield %i: index }: memref } @@ -304,7 +328,7 @@ // ----- func @indexed_generic_result_count(%arg0: memref) { - // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing linalg.generic op (2)}} + // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -349,6 +373,38 @@ // ----- +func @generic_result_tensor_type(%arg0: memref(off + i)>>) { + // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}} + %0 = linalg.generic { + args_in = 0, + args_out = 1, + indexing_maps = [ affine_map<(i) -> (i)> ], + iterator_types = ["parallel"] + } %arg0 { + ^bb(%i: f32): + linalg.yield %i: f32 + }: memref(off + i)>> -> f32 +} + +// ----- + +func @generic(%arg0: memref) { + // expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}} + // expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}} + linalg.generic { + args_in = 0, + args_out = 1, + indexing_maps = [ affine_map<(i) -> 
(i)> ], + iterator_types = ["parallel"] + } %arg0 { + ^bb(%0: i4) : + %1 = std.addf %0, %0: i4 + } : memref + return +} + +// ----- + func @generic_result_0_element_type(%arg0: memref) { // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}} linalg.dot(%arg0, %arg0): memref, memref @@ -420,3 +476,11 @@ memref, memref<2x3xf32>, memref return } + +// ----- + +func @named_ops(%a3: memref, %b3: memref, %c3: memref) { + // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref'}} + linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> () + return +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -620,3 +620,16 @@ // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref + + +// TODO: Return tensors need a semantics convention update. +func @named_ops(%a3: memref, %b3: memref, %c3: memref, + %ta3: tensor, %tb3: tensor, %tc3: tensor) { + linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> () + linalg.batch_matmul %ta3, %tb3, %c3 : (tensor, tensor, memref) -> () + return +} +// CHECK-LABEL: func @named_ops +// CHECK: linalg.batch_matmul +// CHECK: linalg.batch_matmul + diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt --- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt +++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt @@ -1,4 +1,9 @@ set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td) +# Silent dependency at a distance .. 
LinalgStructuredOps.td includes +# an auto generated .td tracked by target MLIRLinalgNamedStructuredOpsIncGen +set(TABLEGEN_OUTPUT + ${TABLEGEN_OUTPUT} + MLIRLinalgNamedStructuredOpsIncGen) mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen) @@ -7,5 +12,10 @@ add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen) set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td) +# Silent dependency at a distance .. LinalgStructuredOps.td includes +# an auto generated .td tracked by target MLIRLinalgNamedStructuredOpsIncGen +set(TABLEGEN_OUTPUT + ${TABLEGEN_OUTPUT} + MLIRLinalgNamedStructuredOpsIncGen) mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters) add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -1,75 +1,77 @@ // RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS // RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL -// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 -test-emit-include-td-header \ -// RUN: | mlir-tblgen -gen-op-decls -I %S/../../include - -// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [ -// ODS-NEXT: NInputs<2>, -// ODS-NEXT: NOutputs<1>, -// ODS-NEXT: NamedStructuredOpTraits]> +// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [ +// ODS-NEXT: NInputs<2> +// ODS-NEXT: NOutputs<1> +// ODS-NEXT: NamedStructuredOpTraits +// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: matvec::referenceIterators() { +// IMPL-LABEL: Test1Op::referenceIterators() { // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: matvec::referenceIndexingMaps() { -// IMPL: AffineMap::get(2, 0, 
{d0, d1}), -// IMPL-NEXT: AffineMap::get(2, 0, {d1}), -// IMPL-NEXT: AffineMap::get(2, 0, {d0}) }; +// IMPL: Test1Op::referenceIndexingMaps() { +// IMPL: AffineMap::get(2, 0, {d0, d1}, context), +// IMPL-NEXT: AffineMap::get(2, 0, {d1}, context), +// IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) }; // -// IMPL: matvec::regionBuilder(ArrayRef args) { +// IMPL: Test1Op::regionBuilder(Block &block) { // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // -def matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { +ods_def : +def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { C(m) = std_addf(std_mulf(A(m, k), B(k))); } -// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [ -// ODS-NEXT: NInputs<2>, -// ODS-NEXT: NOutputs<1>, -// ODS-NEXT: NamedStructuredOpTraits]> +// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [ +// ODS-NEXT: NInputs<2> +// ODS-NEXT: NOutputs<1> +// ODS-NEXT: NamedStructuredOpTraits +// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: matmul::referenceIterators() { +// IMPL-LABEL: Test2Op::referenceIterators() { // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: matmul::referenceIndexingMaps() { -// IMPL: AffineMap::get(3, 0, {d0, d2}), -// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}), -// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}) }; +// IMPL: Test2Op::referenceIndexingMaps() { +// IMPL: AffineMap::get(3, 0, {d0, d2}, context), +// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}, context), +// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) }; // -// IMPL: matmul::regionBuilder(ArrayRef args) { +// IMPL: Test2Op::regionBuilder(Block &block) { // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: 
ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // -def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { +ods_def : +def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } -// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [ -// ODS-NEXT: NInputs<2>, -// ODS-NEXT: NOutputs<1>, -// ODS-NEXT: NamedStructuredOpTraits]> +// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [ +// ODS-NEXT: NInputs<2> +// ODS-NEXT: NOutputs<1> +// ODS-NEXT: NamedStructuredOpTraits +// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // -// IMPL-LABEL: batchmatmul::referenceIterators() { +// IMPL-LABEL: Test3Op::referenceIterators() { // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } // -// IMPL: batchmatmul::referenceIndexingMaps() { -// IMPL: AffineMap::get(4, 0, {d0, d1, d3}), -// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}), -// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}) }; +// IMPL: Test3Op::referenceIndexingMaps() { +// IMPL: AffineMap::get(4, 0, {d0, d1, d3}, context), +// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}, context), +// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) }; // -// IMPL: batchmatmul::regionBuilder(ArrayRef args) { +// IMPL: Test3Op::regionBuilder(Block &block) { // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // -// TBLGEN: batchmatmulOp -def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { +ods_def : +def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- 
a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -90,6 +90,7 @@ // Keywords. kw_def, FIRST_KEYWORD = kw_def, + kw_ods_def, kw_floordiv, kw_ceildiv, kw_mod, @@ -289,6 +290,7 @@ StringRef str(tokStart, curPtr - tokStart); Token::Kind kind = llvm::StringSwitch(str) .Case("def", Token::Kind::kw_def) + .Case("ods_def", Token::Kind::kw_ods_def) .Case("floordiv", Token::Kind::kw_floordiv) .Case("ceildiv", Token::Kind::kw_ceildiv) .Case("mod", Token::Kind::kw_mod) @@ -896,7 +898,8 @@ TensorExpr(StringRef name, SmallVectorImpl> &&exprs, ArrayRef reductionDims) - : Expression(Kind::TensorExpr), opId(name), expressions(std::move(exprs)), + : Expression(Kind::TensorExpr), operationName(name), + expressions(std::move(exprs)), reductionDimensions(reductionDims.begin(), reductionDims.end()) {} static bool classof(const Expression *e) { @@ -904,7 +907,7 @@ } bool operator==(const TensorExpr &other) const { - if (opId != other.opId) + if (operationName != other.operationName) return false; if (expressions.size() != other.expressions.size()) return false; @@ -922,7 +925,7 @@ template void visit(Lambda callback) const; - StringRef opId; + StringRef operationName; SmallVector, 4> expressions; SetVector reductionDimensions; }; @@ -988,22 +991,22 @@ /// When `gen-impl` is used, this prints the C++ implementation for the extra /// methods defined in ODS (referenceIterators, referenceIndexingMaps and /// regionBuilder). - LogicalResult parseAndEmitTCDef(llvm::raw_ostream &os); + LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os); /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. void printODS(llvm::raw_ostream &os, StringRef cppOpName, StringRef linalgOpName); /// Print the C++ StructuredOpsInterface impl of `referenceIterators`. 
- void printReferenceIterators(llvm::raw_ostream &os, StringRef opId, + void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. - void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId, + void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. - void printRegionBuilder(llvm::raw_ostream &os, StringRef opId, + void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state); private: @@ -1346,7 +1349,7 @@ return success(); } -/// Parse and print the information for a TC def. +/// Parse and print the information for a ODS def. /// /// tensor-def-list ::= tensor-def (`,` tensor-def )* /// @@ -1355,16 +1358,29 @@ /// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` /// `{` comprehension-list `}` /// +/// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def +/// /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e. /// contain only expressions involving symbols and constants), but can /// otherwise contain arbitrary affine expressions. 
-LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) { +LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { + if (failed(parser.parseToken(Token::Kind::kw_ods_def, + "expected 'ods_def' to define a TC ODS")) || + failed(parser.parseToken(Token::Kind::lt, "expected '<'"))) + return failure(); + StringRef cppOpName = parser.curToken.getSpelling(); + LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n"); + + if (failed(parser.parseToken(Token::Kind::id, "expected id")) || + failed(parser.parseToken(Token::Kind::gt, "expected '>'")) || + failed(parser.parseToken(Token::Kind::colon, "expected ':'"))) + return failure(); if (failed(parser.parseToken(Token::Kind::kw_def, "expected 'def' to define a TC"))) return failure(); StringRef tcName = parser.curToken.getSpelling(); - LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing tc: " << tcName << "\n"); + LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n"); if (failed(parser.parseToken(Token::Kind::id, "expected id")) || failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) return failure(); @@ -1404,7 +1420,7 @@ SmallVector perComprehensionStates; while (parser.curToken.isNot(Token::Kind::r_brace)) { perComprehensionStates.push_back(ComprehensionParsingState()); - if (failed(parseOneComprehension(tcName, tcName, + if (failed(parseOneComprehension(cppOpName, tcName, perComprehensionStates.back()))) return failure(); }; @@ -1418,16 +1434,16 @@ return failure(); } if (genODSDecl) { - printODS(os, tcName, tcName); + printODS(os, cppOpName, tcName); os << "\n"; } if (genODSImpl) { auto &state = perComprehensionStates.back(); std::string extraMethods; llvm::raw_string_ostream ss(extraMethods); - printReferenceIterators(ss, tcName, state); - printReferenceIndexingMaps(ss, tcName, state); - printRegionBuilder(ss, tcName, state); + printReferenceIterators(ss, cppOpName, state); + printReferenceIndexingMaps(ss, cppOpName, state); + printRegionBuilder(ss, 
cppOpName, state); ss.flush(); os << extraMethods << "\n"; } @@ -1442,18 +1458,32 @@ /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, StringRef linalgOpName) { - const char *header = R"FMT( def {0}Op : LinalgNamedStructured_Op<"{1}", [ + const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [ NInputs<{2}>, NOutputs<{3}>, - NamedStructuredOpTraits]> { + NamedStructuredOpTraits, + SingleBlockImplicitTerminator<"YieldOp">]> { let arguments = (ins Variadic:$views); let results = (outs Variadic:$output_tensors); + let regions = (region SizedRegion<1>:$region); + let builders = [OpBuilder< + "Builder *b, OperationState &result, TypeRange outputTypes, " + # "ValueRange views", + [{{ + result.addOperands(views); + result.addTypes(outputTypes); + buildNamedStructuredOpRegion<{0}>( + *b, result, TypeRange(views), outputTypes); + }]> + ]; + let parser = [{ + return ::parseNamedStructuredOp<{0}>(parser, result); + }]; let extraClassDeclaration = [{{ llvm::Optional> referenceIterators(); llvm::Optional> referenceIndexingMaps(); - void regionBuilder(ArrayRef args); + static void regionBuilder(Block &block); }]; - let hasFolder = 1; })FMT"; unsigned nInputs = 0, nOutputs = 0; @@ -1468,7 +1498,8 @@ } /// Print the C++ StructuredOpsInterface impl of `referenceIterators`. -void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId, +void TCParser::printReferenceIterators(llvm::raw_ostream &os, + StringRef cppOpName, ComprehensionParsingState &state) { const char *referenceReferenceIteratorsFmt = R"FMT( @@ -1498,11 +1529,12 @@ }); ss.flush(); - os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr); + os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr); } /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. 
-void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId, +void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, + StringRef cppOpName, ComprehensionParsingState &state) { const char *referenceIndexingMapsFmt = R"FMT( @@ -1527,7 +1559,7 @@ orderedUses[it.second] = it.first; llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) { assert(u.indexingMap); - const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1})"; + const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)"; if (u.indexingMap.isEmpty()) { mapsStringStream << llvm::formatv("\n\tAffineMap::get({0}, 0, context)", state.dims.size()); return; @@ -1544,11 +1576,11 @@ }); mapsStringStream.flush(); - os << llvm::formatv(referenceIndexingMapsFmt, opId, dimsStr, mapsStr); + os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr); } /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. -void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId, +void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state) { unsigned count = state.orderedTensorArgs.size(); llvm::DenseMap subExprsMap; @@ -1570,15 +1602,17 @@ }); subExprsStringStream.flush(); const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});"; - os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs); + os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName, + subExprs); subExprsMap[pTensorExpr] = count; } }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(ArrayRef args) { + void {0}::regionBuilder(Block &block) { using namespace edsc; using namespace intrinsics; + auto args = block.getArguments(); ValueHandle {1}; {2} (linalg_yield(ValueRange{ {3} })); @@ -1612,8 +1646,8 @@ expressionStringStream.flush(); yieldStringStream.flush(); - os << llvm::formatv(regionBuilderFmt, opId, valueHandleStr, expressionsStr, - yieldStr); + os << llvm::formatv(regionBuilderFmt, cppOpName,
valueHandleStr, + expressionsStr, yieldStr); } /// Iterate over each Tensor Comprehension def. @@ -1621,7 +1655,7 @@ Parser &parser) { while (parser.curToken.getKind() != Token::Kind::eof) { TCParser tcParser(parser); - if (failed(tcParser.parseAndEmitTCDef(os))) + if (failed(tcParser.parseAndEmitODSDef(os))) return failure(); } return success();