diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h
copy from mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
copy to mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h
@@ -1,15 +1,14 @@
-//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===//
+//===- FoldedIntrinsics.h - MLIR EDSC Intrinsics for Linalg -----*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
-#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+#ifndef MLIR_DIALECT_LINALG_EDSC_FOLDEDINTRINSICS_H_
+#define MLIR_DIALECT_LINALG_EDSC_FOLDEDINTRINSICS_H_
 
-#include "mlir/Dialect/Linalg/EDSC/Builders.h"
-#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir {
@@ -24,19 +23,7 @@
                                  ScopedContext::getLocation(), args...));
 }
 
-namespace intrinsics {
-using linalg_copy = OperationBuilder<linalg::CopyOp>;
-using linalg_dot = OperationBuilder<linalg::DotOp>;
-using linalg_fill = OperationBuilder<linalg::FillOp>;
-using linalg_matmul = OperationBuilder<linalg::MatmulOp>;
-using linalg_matvec = OperationBuilder<linalg::MatvecOp>;
-using linalg_range = ValueBuilder<linalg::RangeOp>;
-using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
-using linalg_slice = ValueBuilder<linalg::SliceOp>;
-using linalg_yield = OperationBuilder<linalg::YieldOp>;
-
-} // namespace intrinsics
 } // namespace edsc
 } // namespace mlir
 
-#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
+#endif // MLIR_DIALECT_LINALG_EDSC_FOLDEDINTRINSICS_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -8,22 +8,11 @@
 #ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
 #define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
 
-#include "mlir/Dialect/Linalg/EDSC/Builders.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 
 namespace mlir {
 namespace edsc {
-
-template <typename Op, typename... Args>
-ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) {
-  return folder ? ValueHandle(folder->create<Op>(ScopedContext::getBuilder(),
-                                                 ScopedContext::getLocation(),
-                                                 args...))
-                : ValueHandle(ScopedContext::getBuilder().create<Op>(
-                      ScopedContext::getLocation(), args...));
-}
-
 namespace intrinsics {
 using linalg_copy = OperationBuilder<linalg::CopyOp>;
 using linalg_dot = OperationBuilder<linalg::DotOp>;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -0,0 +1,3 @@
+def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+  C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n)));
+}
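For orientation, the three-line TC spec above is shorthand for a region-carrying structured op over batch/m/n parallel loops and a k reduction loop. A minimal hand-written `linalg.generic` equivalent is sketched below; the indexing maps follow mechanically from the TC index usage, while the trait name and the fully dynamic shapes are illustrative assumptions:

```mlir
#batchmatmul_accesses = [
  affine_map<(b, m, n, k) -> (b, m, k)>,
  affine_map<(b, m, n, k) -> (k, n)>,
  affine_map<(b, m, n, k) -> (b, m, n)>
]
#batchmatmul_trait = {
  args_in = 2,
  args_out = 1,
  indexing_maps = #batchmatmul_accesses,
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}
func @batchmatmul_as_generic(%A: memref<?x?x?xf32>, %B: memref<?x?xf32>,
                             %C: memref<?x?x?xf32>) {
  linalg.generic #batchmatmul_trait %A, %B, %C {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %d = mulf %a, %b : f32
      %e = addf %c, %d : f32
      linalg.yield %e : f32
  } : memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>
  return
}
```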
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -523,7 +523,6 @@
                        AffineMapArrayAttr:$indexing_maps,
                        ArrayAttr:$iterator_types,
                        OptionalAttr<StrAttr>:$doc,
-                       OptionalAttr<FlatSymbolRefAttr>:$fun,
                        OptionalAttr<StrAttr>:$library_call);
   let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
   let regions = (region AnyRegion:$region);
@@ -531,7 +530,7 @@
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
      return SmallVector<StringRef, 8>{
        getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
-       getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
+       getIndexingMapsAttrName(), getLibraryCallAttrName(),
        getIteratorTypesAttrName()
      };
    }
@@ -540,12 +539,6 @@
     unsigned getNumOutputs() { return args_out().getSExtValue(); }
 
-    FuncOp getFunction() {
-      auto moduleOp = getParentOfType<ModuleOp>();
-      return fun().hasValue() ?
-          moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
-    }
-
     StringRef getLibraryCallName() {
      return library_call().hasValue() ? library_call().getValue() : "";
    }
@@ -581,13 +574,6 @@
       - args_in: an I64Attr representing the number of input (readonly) views
      - args_out: an I64Attr representing the number of output (readwrite) views
      - doc [optional]: a documentation string
-      - fun: a FlatSymbolRefAttr that must resolve to an existing function
-        symbol. To support inplace updates in a generic fashion, the signature
-        of the function must be:
-        ```
-          fun([input views element types], [output views element types])
-            -> ([output views element types])
-        ```
      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
        and output view. Such AffineMapAttr specifies the mapping between the
        loops and the indexing within each view.
@@ -604,11 +590,6 @@
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
     ```mlir
-      func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
-        %d = mulf %a, %b: f32
-        %e = addf %c, %d: f32
-        return %e: f32
-      }
      #matmul_accesses = [
        (m, n, k) -> (m, k),
        (m, n, k) -> (k, n),
@@ -616,7 +597,6 @@
       ]
      #matmul_trait = {
        doc = "C(m, n) += A(m, k) * B(k, n)",
-        fun = @fma,
        indexing_maps = #matmul_accesses,
        library_call = "linalg_matmul",
        n_views = [2, 1],
@@ -626,10 +606,14 @@
 
     And can be reused in multiple places as:
     ```mlir
-      linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
-        memref<?x?xf32, stride_specification>,
-        memref<?x?xf32, stride_specification>,
-        memref<?x?xf32, stride_specification>
+      linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
+        ^bb0(%a: f32, %b: f32, %c: f32):
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          linalg.yield %e : f32
+      } : memref<?x?xf32, stride_specification>,
+          memref<?x?xf32, stride_specification>,
+          memref<?x?xf32, stride_specification>
     ```
 
     This may lower to either:
@@ -649,9 +633,9 @@
           %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
          %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
          %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          %d = call @func_of_elements(%a, %b, %c)
-                 : (f32, f32, f32) -> (f32)
-          store %d, %C[%m, %n] : memref<?x?xf32, stride_specification>
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
        }
      }
    }
@@ -662,7 +646,7 @@
     mixing input and output ranked tensor values with input and output memrefs.
    ```mlir
-      %C = linalg.generic #trait_attribute %A, %B {other-attributes} :
+      %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
        tensor<?x?xf32>,
        memref<?x?xf32, stride_specification>
          -> (tensor<?x?xf32>)
@@ -708,13 +692,6 @@
       - args_in: an I64Attr representing the number of input (readonly) views
      - args_out: an I64Attr representing the number of output (readwrite) views
      - doc [optional]: a documentation string
-      - fun: a FlatSymbolRefAttr that must resolve to an existing function
-        symbol. To support inplace updates in a generic fashion, the signature
-        of the function must be:
-        ```
-          fun([index types of induction variables], [input views element types],
-              [output views element types]) -> ([output views element types])
-        ```
      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
        and output view. Such AffineMapAttr specifies the mapping between the
        loops and the indexing within each view.
@@ -732,15 +709,6 @@
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
     ```mlir
-      func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
-                %a: f32, %b: f32, %c: f32)
-        -> f32
-      {
-        "some_optional_condition"(%offset_m, %offset_n, %offset_k)
-        %d = mulf %a, %b: f32
-        %e = addf %c, %d: f32
-        return %e: f32
-      }
      #matmul_accesses = [
        (m, n, k) -> (m, k),
        (m, n, k) -> (k, n),
@@ -748,7 +716,6 @@
       ]
      #matmul_trait = {
        doc = "C(m, n) += A(m, k) * B(k, n)",
-        fun = @fma,
        indexing_maps = #matmul_accesses,
        library_call = "linalg_matmul",
        n_views = [2, 1],
@@ -759,10 +726,16 @@
 
     And can be reused in multiple places as:
     ```mlir
-      linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] :
-        memref<?x?xf32, stride_specification>,
-        memref<?x?xf32, stride_specification>,
-        memref<?x?xf32, stride_specification>
+      linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
+        ^bb0(%offset_m: index, %offset_n: index, %offset_k: index,
+             %a: f32, %b: f32, %c: f32):
+          "some_optional_computation"(%offset_m, %offset_n, %offset_k)
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          linalg.yield %e : f32
+      } : memref<?x?xf32, stride_specification>,
+          memref<?x?xf32, stride_specification>,
+          memref<?x?xf32, stride_specification>
     ```
 
     This may lower to either:
@@ -784,8 +757,9 @@
           %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
          %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
          %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
-          %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c)
-                 : (index, index, index, f32, f32, f32) -> (f32)
-          store %d, %C[%m, %n] : memref<?x?xf32, stride_specification>
+          "some_optional_computation"(%m, %n, %k)
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
        }
      }
@@ -832,11 +806,22 @@
 def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
 
 class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
-  : Op<Linalg_Dialect, mnemonic, props> {
+  : LinalgStructuredBase_Op<mnemonic, props> {
   string spec = ?;
-  let assemblyFormat = "`(` operands `)` attr-dict `:` "
-                       "functional-type(operands, results)";
+  // We cannot use an assemblyFormat atm because we need to hook in a custom-
+  // built implicit region from a static OpClass method.
+  // TODO(ntv): Revisit in the future if/when appropriate.
+  // let assemblyFormat = "`(` operands `)` attr-dict `:` "
+  //                      "functional-type(operands, results)";
+
+  // The parser needs to specialize on the OpType so it has to be auto-generated
+  // in the linalg-ods tool.
+  let printer = [{ return ::printNamedStructuredOp(p, *this); }];
+  let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
+  let hasFolder = 1;
 }
 
+// This file is auto-generated from a tc specification.
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
+
 #endif // LINALG_STRUCTURED_OPS
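To summarize the user-visible effect of the op-definition changes above: the scalar body that used to be factored into a separate function and referenced via `fun = @...` is now carried inline as a region, terminated by `linalg.yield`. A hedged before/after sketch of the same matmul body, with trait name and shapes kept illustrative:

```mlir
// Before: body behind a symbol, e.g. #matmul_trait = {..., fun = @fma, ...}.
// After: the same computation inlined in the op's region.
linalg.generic #matmul_trait %A, %B, %C {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %d = mulf %a, %b : f32
    %e = addf %c, %d : f32
    linalg.yield %e : f32
} : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
```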
}], "SmallVector", "getOutputTensorTypes">, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" @@ -119,7 +120,8 @@ return it - getInputs().begin(); return llvm::None; } - /// Return the `i`-th input buffer type. + /// Return the `i`-th input shaped type, irrespective of buffer of tensor + /// type. ShapedType getInputShapedType(unsigned i) { return getInput(i).getType().template cast(); } @@ -344,6 +346,17 @@ } }; +/// This class provides the API for named Linalg StructuredOps. +template +class NamedStructuredOpTraits + : public OpTrait::TraitBase { +public: + llvm::Optional> referenceIterators(); + llvm::Optional> referenceIndexingMaps(); + std::function)> + emitScalarImplementation(); +}; + } // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -66,10 +66,6 @@ /// string of the structured op. constexpr StringRef getDocAttrName() { return "doc"; } -/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the -/// MLIR function that implements the body of the structured op. -constexpr StringRef getFunAttrName() { return "fun"; } - /// Attribute name for the StrArrayAttr which encodes the external library /// function that implements the structured op. constexpr StringRef getLibraryCallAttrName() { return "library_call"; } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/Builders.h" #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" +#include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -177,7 +178,6 @@ builder.getAffineMapArrayAttr(maps), builder.getStrArrayAttr(iteratorStrTypes), StringAttr() /*doc*/, - FlatSymbolRefAttr() /*fun*/, StringAttr() /*library_call*/ /* TODO: other attributes in op */ ) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" @@ -30,6 +31,20 @@ using namespace mlir; using namespace mlir::linalg; +/// Forward declarations. 
+template +static void buildNamedStructuredOpRegion(Builder &builder, + OperationState &result, + TypeRange operandTypes, + TypeRange tensorResultTypes); +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); +template +static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); + /// Determines whether it is possible to fold it away in the parent Linalg op: /// /// ```mlir @@ -133,10 +148,11 @@ attrs.push_back(attr); auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); - p << op.getOperationName() << " " << dictAttr << " " << op.getOperands(); + p << op.getOperationName() << " " << dictAttr; + p.printOptionalAttrDict(op.getAttrs(), attrNames); + p << " " << op.getOperands(); if (!op.region().empty()) p.printRegion(op.region()); - p.printOptionalAttrDict(op.getAttrs(), attrNames); p << ": " << op.getOperandTypes(); auto outputTensorTypes = op.getResultTypes(); if (!outputTensorTypes.empty()) @@ -156,21 +172,21 @@ // The name is unimportant as we will overwrite result.attributes. // The core linalg traits must contain the information necessary to pass the // verifier. - if (parser.parseAttribute(dictAttr, "_", result.attributes) || - parser.parseOperandList(operandsInfo)) + if (parser.parseAttribute(dictAttr, "_", result.attributes)) return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); + // Optional attributes may be added. + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseOperandList(operandsInfo)) + return failure(); + Region ®ion = *result.addRegion(); SmallVector operandTypes, regionTypes; - // Optional attributes may be added. - // Either Optional getFunAttrName() attribute or region must be specified. - if (!dictAttr.get(getFunAttrName()) && - parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes)) + if (parser.parseRegion(region, regionOperandsInfo, regionTypes)) return failure(); - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonTypeList(operandTypes)) + if (parser.parseColonTypeList(operandTypes)) return failure(); // Generic ops may specify that a subset of its outputs are tensors. Such // outputs are specified in the result type. 
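The net effect of the print/parse reordering above on the textual form: the optional attribute dictionary now appears right after the trait dictionary (rather than after the region), and the region is mandatory and precedes the trailing types. A sketch of the resulting form, assuming a `#some_trait` trait dictionary and illustrative 1-D shapes (this mirrors the updated roundtrip tests further down):

```mlir
linalg.generic #some_trait {foo = 1} %arg0, %arg1 {
  ^bb0(%a: f32, %b: f32):
    linalg.yield %b : f32
} : memref<?xf32>, memref<?xf32>
```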
@@ -184,9 +200,13 @@
 }
 
 template <typename GenericOpType>
-static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
+struct BlockArgsVerifier {
+  static LogicalResult verify(GenericOpType op, Block &block);
+};
 
-template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+template <typename GenericOpType>
+LogicalResult BlockArgsVerifier<GenericOpType>::verify(GenericOpType op,
+                                                       Block &block) {
   auto nOperands = op.getNumOperands();
   if (block.getNumArguments() != nOperands)
     return op.emitOpError("expected number of block arguments to match number "
@@ -205,7 +225,9 @@
   return success();
 }
 
-template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+template <>
+LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op,
+                                                          Block &block) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
   auto nOperands = op.getNumOperands();
@@ -234,81 +256,6 @@
   return success();
 }
 
-template <typename GenericOpType>
-static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
-
-template <typename GenericOpType>
-static LogicalResult verifyFuncArgsGeneric(GenericOpType op,
-                                           FunctionType funType) {
-  auto res = verifyFuncArgs(op, funType);
-  if (failed(res))
-    return res;
-
-  auto nInputs = op.getNumInputs();
-  auto nOutputs = op.getNumOutputs();
-  // linalg.generic output element types are exactly the function results.
-  for (unsigned idx = 0; idx < nOutputs; ++idx) {
-    ShapedType shapedType = op.getShapedType(nInputs + idx);
-    if (funType.getResult(idx) != shapedType.getElementType())
-      return op.emitOpError("expected function result ")
-             << (idx + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of output " << (idx + 1);
-  }
-  return success();
-}
-
-template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
-  auto nOperands = op.getNumOperands();
-  if (funType.getNumInputs() != nOperands)
-    return op.emitOpError(
-        "expected function arguments to match number of operands");
-  if (funType.getNumResults() != op.getNumOutputs())
-    return op.emitOpError("expected function results(")
-           << funType.getNumResults() << ") to match number of outputs("
-           << op.getNumOutputs() << ")";
-
-  // linalg.generic operands element types are exactly the first function
-  // arguments.
-  for (unsigned idx = 0; idx < nOperands; ++idx) {
-    ShapedType shapedType = op.getShapedType(idx);
-    if (funType.getInput(idx) != shapedType.getElementType())
-      return op.emitOpError("expected function argument ")
-             << (idx + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of operand " << (idx + 1);
-  }
-
-  return success();
-}
-
-template <>
-LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
-  auto nLoops = op.getNumLoops();
-  auto nOutputs = op.getNumOutputs();
-  auto nOperands = op.getNumOperands();
-  if (funType.getNumInputs() != nOperands + nLoops)
-    return op.emitOpError("expected function arguments to match number of "
-                          "loops + number of operands");
-  if (funType.getNumResults() != nOutputs)
-    return op.emitOpError(
-        "expected function results to match number of outputs");
-  for (unsigned i = 0; i < nLoops; ++i)
-    if (!funType.getInput(i).isIndex())
-      return op.emitOpError("expected function argument ")
-             << (i + 1) << " to be an index";
-
-  // linalg.generic operands element types are exactly the first function
-  // arguments.
-  for (unsigned idx = 0; idx < nOperands; ++idx) {
-    ShapedType shapedType = op.getShapedType(idx);
-    if (funType.getInput(idx + nLoops) != shapedType.getElementType())
-      return op.emitOpError("expected function argument ")
-             << (idx + nLoops + 1) << " of the same type as elemental type "
-             << shapedType.getElementType() << " of input " << (idx + 1);
-  }
-
-  return success();
-}
-
 template <typename GenericOpType>
 static LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
@@ -320,20 +267,11 @@
            << " inputs (tensor or buffer) and output buffer operands";
 
   auto &region = op.region();
-  auto funOp = op.getFunction();
-  auto funType = funOp ? funOp.getType() : FunctionType();
-  if (!region.empty()) {
-    if (region.getBlocks().size() != 1)
-      return op.emitOpError("expected region with 1 block");
-    if (failed(verifyBlockArgs(op, region.getBlocks().front())))
-      return failure();
-  } else {
-    if (!funOp || !funOp.getType())
-      return op.emitOpError(
-          "expected function attribute to refer to a defined symbol");
-    if (failed(verifyFuncArgsGeneric(op, funType)))
-      return failure();
-  }
+  if (region.getBlocks().size() != 1)
+    return op.emitOpError("expected region with 1 block");
+  if (failed(BlockArgsVerifier<GenericOpType>::verify(
+          op, region.getBlocks().front())))
+    return failure();
 
   SmallVector<AffineMap, 4> indexingMaps;
   indexingMaps.reserve(op.indexing_maps().size());
@@ -824,17 +762,17 @@
                  parser.resolveOperands(opInfo, types, loc, result.operands));
 }
 
-template <typename GenericOpType>
-static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
+static LogicalResult verifyYield(YieldOp op, LinalgOp linalgOpInterface) {
   // The operand number and types must match the view element types.
-  auto nOutputs = genericOp.getNumOutputs();
+  auto nOutputs = linalgOpInterface.getNumOutputs();
   if (op.getNumOperands() != nOutputs)
     return op.emitOpError("expected number of yield values (")
            << nOutputs << ") to match the number of operands of the enclosing "
-           << "linalg.generic op (" << op.getNumOperands() << ")";
+           << "LinalgOp (" << op.getNumOperands() << ")";
 
   for (unsigned i = 0; i != nOutputs; ++i) {
-    auto elementType = genericOp.getOutputShapedType(i).getElementType();
+    auto elementType =
+        linalgOpInterface.getOutputShapedType(i).getElementType();
     if (op.getOperand(i).getType() != elementType)
       return op.emitOpError("type of yield operand ")
              << (i + 1) << " (" << op.getOperand(i).getType()
@@ -850,17 +788,10 @@
   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
     return op.emitOpError("expected single non-empty parent region");
 
-  auto genericOp = dyn_cast<GenericOp>(parentOp);
-  if (genericOp)
-    return verifyYield(op, genericOp);
+  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
+    return verifyYield(op, cast<LinalgOp>(parentOp));
 
-  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(parentOp);
-  if (indexedGenericOp)
-    return verifyYield(op, indexedGenericOp);
-
-  return op.emitOpError("expected '")
-         << GenericOp::getOperationName() << "' or '"
-         << IndexedGenericOp::getOperationName() << "' parent op";
+  return op.emitOpError("expected parent op with LinalgOp interface");
 }
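With yield verification rephrased against the `LinalgOp` interface, any structured-op region (including the auto-generated named ops introduced below) is checked the same way: one yielded value per declared output, each matching that output's element type. A sketch, assuming an illustrative `#some_trait` with `args_in = 1, args_out = 1` over `f32` buffers:

```mlir
linalg.generic #some_trait %in, %out {
  ^bb0(%a: f32, %b: f32):
    %0 = addf %a, %b : f32
    // One yielded value, of type f32: matches the single f32 output.
    linalg.yield %0 : f32
} : memref<?xf32>, memref<?xf32>
```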
 /////// Operations corresponding to library calls defined with Tablegen ////////
@@ -1143,3 +1074,92 @@
     return getResult();
   return {};
 }
+
+//===----------------------------------------------------------------------===//
+// Auto-generated Linalg named ops.
+//===----------------------------------------------------------------------===//
+
+template <typename NamedStructuredOpType>
+void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
+                                  TypeRange operandTypes,
+                                  TypeRange tensorResultTypes) {
+  auto *op = Operation::create(builder.getUnknownLoc(),
+                               OperationName("fake_op", builder.getContext()),
+                               ArrayRef<Type>{}, ArrayRef<Value>{},
+                               ArrayRef<NamedAttribute>{}, ArrayRef<Block *>{},
+                               /*numRegions=*/1,
+                               /*resizableOperandList=*/false);
+  std::unique_ptr<int, std::function<void(int *)>> guard(
+      (int *)1, [&op](int *) { op->destroy(); });
+
+  Region &bodyRegion = op->getRegion(0);
+  Block *body = new Block();
+  // TODO(ntv): atm all operands go through getElementTypeOrSelf,
+  // reconsider when we have evidence we need to.
+  for (auto t : operandTypes)
+    body->addArgument(getElementTypeOrSelf(t));
+  for (auto t : tensorResultTypes)
+    body->addArgument(getElementTypeOrSelf(t));
+  bodyRegion.push_back(body);
+  OpBuilder opBuilder(bodyRegion);
+  mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
+  NamedStructuredOpType::regionBuilder(*body);
+
+  // Steal the region and let op be destroyed.
+  result.addRegion()->takeBody(bodyRegion);
+}
+
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  p << op.getOperationName() << " ";
+  p.printOptionalAttrDict(op.getAttrs());
+  p << "(" << op.getOperands() << ")";
+  p << ": (" << op.getOperandTypes() << ")";
+  auto outputTensorTypes = op.getResultTypes();
+  if (!outputTensorTypes.empty())
+    p << " -> (" << outputTensorTypes << ")";
+}
+
+template <typename NamedStructuredOpType>
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
+
+  // Optional attributes may be added.
+  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseLParen() ||
+      parser.parseOperandList(operandsInfo) || parser.parseRParen())
+    return failure();
+
+  SmallVector<Type, 8> operandTypes;
+  if (parser.parseColon() || parser.parseLParen() ||
+      parser.parseTypeList(operandTypes) || parser.parseRParen())
+    return failure();
+
+  // Generic ops may specify that a subset of its outputs are tensors. Such
+  // outputs are specified in the result type.
+  SmallVector<Type, 8> tensorResultTypes;
+  if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+    return failure();
+
+  if (!tensorResultTypes.empty())
+    result.addTypes(tensorResultTypes);
+
+  buildNamedStructuredOpRegion<NamedStructuredOpType>(
+      parser.getBuilder(), result, operandTypes, tensorResultTypes);
+
+  return parser.resolveOperands(operandsInfo, operandTypes,
+                                parser.getCurrentLocation(), result.operands);
+}
+
+template <typename NamedStructuredOpType>
+static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
+  return verifyGenericOp<NamedStructuredOpType>(op);
+}
+
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+
+// TODO(ntv): Determine whether we can generate the folders and verifiers.
+LogicalResult batchmatmulOp::fold(ArrayRef<Attribute>,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -13,7 +13,7 @@
 #include "PassDetail.h"
 #include "mlir/Analysis/Dominance.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
-#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -382,8 +382,7 @@
   // - only handle ops that use regions for specifying the scalar operations.
   if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
       producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
-      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
-      producerOp.fun() || consumerOp.fun())
+      producerOp.getNumParallelLoops() != producerOp.getNumLoops())
     return false;
 
   // Get the consumer index map. The number of results of the consumer index map
@@ -472,7 +471,6 @@
       b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
       b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
       /*doc=*/nullptr,
-      /*fun=*/nullptr,
       /*library_call=*/nullptr);
 
   // Build the region of the fused op.
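For reference, once the generic `emitScalarImplementation` hook is fully wired (the default emitter in the next file still bails with `llvm_unreachable("NYI")`), lowering `linalg.batchmatmul` to loops should produce a four-deep nest matching the TC spec above. A hedged sketch with illustrative value names and bounds:

```mlir
loop.for %b = %c0 to %dimB step %c1 {
  loop.for %m = %c0 to %dimM step %c1 {
    loop.for %n = %c0 to %dimN step %c1 {
      loop.for %k = %c0 to %dimK step %c1 {
        %0 = load %A[%b, %m, %k] : memref<?x?x?xf32>
        %1 = load %B[%k, %n] : memref<?x?xf32>
        %2 = load %C[%b, %m, %n] : memref<?x?x?xf32>
        %3 = mulf %0, %1 : f32
        %4 = addf %2, %3 : f32
        store %4, %C[%b, %m, %n] : memref<?x?x?xf32>
      }
    }
  }
}
```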
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -8,7 +8,8 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
-#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -121,8 +122,19 @@
 }
 
 namespace {
+
 template <typename IndexedValueType, typename LinalgOpType>
-class LinalgScopedEmitter {};
+class LinalgScopedEmitter {
+public:
+  static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                       LinalgOpType linalgOp) {
+    assert(linalgOp.hasBufferSemantics() &&
+           "expected linalg op with buffer semantics");
+    llvm_unreachable("NYI");
+    linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(),
+                                        ScopedContext::getLocation(), allIvs);
+  }
+};
 
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, CopyOp> {
@@ -400,21 +412,6 @@
       indexedValues[nInputs + i] = std_load(output, indexing);
     }
 
-    auto funcOp = genericOp.getFunction();
-    if (funcOp) {
-      // 2. Emit call.
-      Operation *callOp = std_call(funcOp, indexedValues);
-      assert(callOp->getNumResults() == genericOp.getNumOutputs());
-
-      // 3. Emit std_store.
-      for (unsigned i = 0; i < nOutputs; ++i) {
-        Value output = genericOp.getOutputBuffer(i);
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), output, indexing);
-      }
-      return;
-    }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
     // 3. Emit std_store.
@@ -495,20 +492,6 @@
       indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
     }
 
-    if (auto funcOp = indexedGenericOp.getFunction()) {
-      // 2. Emit call.
-      Operation *callOp = std_call(funcOp, indexedValues);
-      assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
-
-      // 3. Emit std_store.
-      for (unsigned i = 0; i < nOutputs; ++i) {
-        Value output = indexedGenericOp.getOutputBuffer(i);
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), output, indexing);
-      }
-      return;
-    }
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
     // 3. Emit std_store.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -12,7 +12,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
-#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -12,7 +12,8 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
-#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Builders.h"
+#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -48,206 +48,137 @@
 // -----
 
 func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}}
+  // expected-error @+1 {{op expected parent op with LinalgOp interface}}
   linalg.yield %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 // -----
 
 func @generic_at_least_2_operands(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected 2 or more operands}}
+  // expected-error @+6 {{expected '{' to begin a region}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0: memref<f32>
+  } %arg0 : memref<f32>
 }
 
 // -----
 
-func @generic_exactly_2_views(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}}
+func @generic_at_least_2_operands(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected 2 or more operands}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
+  } %arg0 {} : memref<f32>
 }
 
 // -----
 
-func @generic_undefined_fun(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function attribute to refer to a defined symbol}}
+func @generic_exactly_2_views(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = []
-  } %arg0, %arg0: memref<f32>, memref<f32>
+  } %arg0, %arg0, %arg0 {}: memref<f32>, memref<f32>, memref<f32>
 }
 
 // -----
 
-func @foo() { return }
-
-func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function arguments to match number of operands}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32) { return }
-
 func @generic_mismatched_num_returns(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function results(0) to match number of outputs(1)}}
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (0)}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32, %1: i32, %2: i32) { return }
-
-func @generic_mismatched_num_returns(%0: memref<i32>, %1: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of operand 2}}
-  linalg.generic {
-    args_in = 3,
-    args_out = 0,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %0, %1, %1: memref<i32>, memref<f32>, memref<f32>
-}
-
-// -----
-
-func @foo(%0: i32, %1: i32, %2: f32) -> i32 { return %1: i32}
-
-func @generic_mismatched_num_returns(%0: memref<i32>, %1: memref<f32>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}}
-  linalg.generic {
-    args_in = 2,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
+    indexing_maps = [ affine_map<() -> ()> ],
     iterator_types = []
-  } %0, %0, %1: memref<i32>, memref<i32>, memref<f32>
+  } %arg0 {
+    ^bb(%0: f32):
+      linalg.yield
+  }: memref<f32>
 }
 
 // -----
 
-func @foo(%0: i32) -> i32 { return %0: i32 }
-
 func @generic_symbol_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<()[N] -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<i32>
+  } %arg0 {
+    ^bb(%i : i32):
+  }: memref<i32>
 }
 
 // -----
 
-func @foo(%0: i32) -> i32 { return %0: i32 }
-
 func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<1xi32>
+  } %arg0 {
    ^bb(%i : i32):
+  }: memref<1xi32>
 }
 
 // -----
 
-func @foo(%0: f32) -> f32 { return %0: f32 }
-
 func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, affine_map<(d0)[s0] -> (d0 + s0)>>'}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [ affine_map<() -> (0, 0)> ],
     iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+  } %arg0 {
+    ^bb(%f : f32):
+      linalg.yield %f: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 // -----
 
-func @foo(%0: i32) -> f32 {
-  %1 = constant 0.0: f32
-  return %1: f32
-}
-
-func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of operand 1}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
-}
-
-// -----
-
-func @foo(%0: f32) -> i4 {
-  %1 = constant 1: i4
-  return %1: i4
-}
-
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}}
+func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+  // expected-error @+9 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
-    fun = @foo,
-    indexing_maps = [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } %arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>
+    indexing_maps = [ affine_map<(i) -> (i)> ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%0: f32):
+      %1 = constant 1: i4
+      linalg.yield %1: i4
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 // -----
 
-func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
-
 func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
   linalg.generic {
     args_in = 1,
     args_out = 1,
-    fun = @foo,
     indexing_maps = [
       affine_map<(i, j) -> (i + j)>,
       affine_map<(i, j) -> (i + j)>
     ],
     iterator_types = ["parallel","parallel"]
-  } %arg0, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>, memref<?xf32, affine_map<(i)[off]->(off + i)>>
+  } %arg0, %arg1 {
+    ^bb(%0: f32, %1: f32):
+      linalg.yield %1: f32
+  }: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
+     memref<?xf32, affine_map<(i)[off]->(off + i)>>
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -341,88 +272,53 @@
 
 // -----
 
-func @foo(%f: f32) -> (f32) {
-  return %f : f32
-}
-func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function arguments to match number of loops + number of operands}}
-  linalg.indexed_generic {
-    args_in = 0,
-    args_out = 1,
-    indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
-}
-
-// -----
-
-func @foo(%i: i32, %val: f32) -> (f32) {
-  return %val : f32
-}
-func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 1 to be an index}}
-  linalg.indexed_generic {
-    args_in = 0,
-    args_out = 1,
-    iterator_types = ["parallel"],
-    indexing_maps = [ affine_map<(i) -> (i)> ],
-    fun = @foo
-  } %arg0 : memref<f32>
-}
-
-// -----
-
-func @foo(%i: index, %val: i1) -> (i1) {
-  return %val : i1
-}
-func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of input 1}}
+func @indexed_generic_arg_count(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
-    indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
+    indexing_maps = [ affine_map<()[] -> ()> ],
+    iterator_types = []
+  } %arg0 {
+    ^bb(%0: index, %1: f32):
+      linalg.yield %1: f32
+  } : memref<f32>
+  return
 }
 
 // -----
 
-func @foo(%i: index, %val: i1) -> (i1, i1) {
-  return %val, %val : i1, i1
-}
-func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected function results to match number of outputs}}
+func @indexed_generic_induction_var_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected block argument 1 to be an index}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
-    indexing_maps = [ affine_map<(d0) -> (d0)> ],
     iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<f32>
+    indexing_maps = [ affine_map<(i) -> (i)> ]
+  } %arg0 {
+    ^bb(%0: i32, %1: f32):
+      linalg.yield %1: f32
+  } : memref<f32>
 }
 
 // -----
 
-func @foo(%i: index, %val: i32) -> (f32) {
-  %val_float = sitofp %val : i32 to f32
-  return %val_float : f32
-}
-func @indexed_generic_fun_result_count(%arg0: memref<?xi32>) {
-  // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'i32' of output 1}}
+func @indexed_generic_result_count(%arg0: memref<?xf32>) {
+  // expected-error @+8 {{op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
   linalg.indexed_generic {
     args_in = 0,
     args_out = 1,
     indexing_maps = [ affine_map<(d0) -> (d0)> ],
-    iterator_types = ["parallel"],
-    fun = @foo
-  } %arg0: memref<?xi32>
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: index, %val: f32):
+      linalg.yield %val, %val: f32, f32
+  }: memref<?xf32>
 }
 
 // -----
 
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
+func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     args_in = 0,
     args_out = 1,
@@ -453,7 +349,7 @@
 
 // -----
 
-func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
+func @generic_result_0_element_type(%arg0: memref<?xf32>) {
   // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
   linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
 }
@@ -524,3 +420,11 @@
     memref<?x?xf32>, memref<2x3xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>) {
+  // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref<?x?x?xf32>'}}
+  linalg.batchmatmul(%a3, %b3, %c3): (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}
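The rewritten invalid tests above pin down the region contract that replaces the old function-type checks: a `linalg.generic` block takes exactly one argument per operand, and a `linalg.indexed_generic` block additionally takes one leading `index` per loop, with yields checked per the LinalgOp rules shown earlier. A well-formed indexed example over a 1-D iteration space (operand names illustrative):

```mlir
linalg.indexed_generic {
  args_in = 1,
  args_out = 1,
  indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
  iterator_types = ["parallel"]
} %in, %out {
  // 1 loop index + 2 operand elements = 3 block arguments.
  ^bb0(%i: index, %a: f32, %b: f32):
    linalg.yield %a : f32
} : memref<?xf32>, memref<?xf32>
```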
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -533,51 +533,11 @@
 // CHECKPARALLEL:   %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 // CHECKPARALLEL:   store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
 
-func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
-  %f0 = constant 0.0 : f32
-  return %f0, %f0 : f32, f32
-}
 #accesses = [
   affine_map<(i, j, k) -> (i, j)>,
   affine_map<(i, j, k) -> (i, j, k)>,
   affine_map<(i, j, k) -> (i, k, j)>
 ]
-#trait = {
-  args_in = 1,
-  args_out = 2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  fun = @foo,
-  library_call = "some_external_function_name_1",
-  doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
-}
-func @generic_function(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
-  linalg.generic #trait %arg0, %arg1, %arg2:
-    memref<?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>
-  return
-}
-// CHECKLOOP-LABEL: @foo
-// CHECKLOOP-LABEL: @generic_function
-// CHECKLOOP: loop.for %[[i:.*]] = {{.*}}
-// CHECKLOOP:   loop.for %[[j:.*]] = {{.*}}
-// CHECKLOOP:     loop.for %[[k:.*]] = {{.*}}
-// CHECKLOOP:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32>
-// CHECKLOOP:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKLOOP:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-// CHECKLOOP:       %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-// CHECKLOOP:       store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKLOOP:       store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-
-// CHECKPARALLEL-LABEL: @foo
-// CHECKPARALLEL-LABEL: @generic_function
-// CHECKPARALLEL: loop.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL:   %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32>
-// CHECKPARALLEL:   %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-// CHECKPARALLEL:   store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-
 #trait2 = {
   args_in = 1,
   args_out = 2,
@@ -617,52 +577,6 @@
 // CHECKPARALLEL:   store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
 // CHECKPARALLEL:   store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
 
-func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) -> (f32, f32) {
-  %i_int = index_cast %i: index to i32
-  %i_float = sitofp %i_int : i32 to f32
-  return %i_float, %i_float : f32, f32
-}
-#trait3 = {
-  args_in = 1,
-  args_out = 2,
-  iterator_types = ["parallel", "parallel", "parallel"],
-  indexing_maps = #accesses,
-  fun = @indexed_foo,
-  library_call = "some_external_function_name_1",
-  doc = "b(i,j,k), c(i,k,j) = foo(a(i, j), b(i,j,k), c(i,k,j))"
-}
-func @indexed_generic_function(
-    %arg0: memref<?x?xf32>,
-    %arg1: memref<?x?x?xf32>,
-    %arg2: memref<?x?x?xf32>) {
-  linalg.indexed_generic #trait3 %arg0, %arg1, %arg2:
-    memref<?x?xf32>,
-    memref<?x?x?xf32>,
-    memref<?x?x?xf32>
-  return
-}
-// CHECKLOOP-LABEL: @indexed_foo
-// CHECKLOOP-LABEL: @indexed_generic_function
-// CHECKLOOP: loop.for %[[i:.*]] = {{.*}}
-// CHECKLOOP:   loop.for %[[j:.*]] = {{.*}}
-// CHECKLOOP:     loop.for %[[k:.*]] = {{.*}}
-// CHECKLOOP:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32>
-// CHECKLOOP:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKLOOP:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-// CHECKLOOP:       %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
-// CHECKLOOP:       store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKLOOP:       store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-
-// CHECKPARALLEL-LABEL: @indexed_foo
-// CHECKPARALLEL-LABEL: @indexed_generic_function
-// CHECKPARALLEL: loop.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]])
-// CHECKPARALLEL:   %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32>
-// CHECKPARALLEL:   %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
-// CHECKPARALLEL:   store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32>
-// CHECKPARALLEL:   store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32>
-
 #trait4 = {
   args_in = 1,
   args_out = 2,
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -289,11 +289,6 @@
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 
-func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
 ]
@@ -304,46 +299,45 @@
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }
 
 func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
               %arg1: memref<?x?x?xf32>) {
-  linalg.generic #trait %arg0, %arg1 {foo = 1} :
-    memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
-    memref<?x?x?xf32>
+  linalg.generic #trait {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+      memref<?x?x?xf32>
   return
 }
-// CHECK-LABEL: func @foo
 // CHECK-LABEL: func @generic
-// CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME:    indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME:    library_call = "some_external_function_name_1"
-// CHECK-SAME:    {foo = 1 : i64}:
-// CHECK-SAME:    memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32>
+// CHECK-SAME:    {foo = 1 : i64}
+// CHECK:         memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32>
 
 func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
                                 %arg1: memref<?x?x?xf32>) {
-  linalg.generic #trait %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>,
-    memref<?x?x?xf32>
+  linalg.generic #trait {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>,
+      memref<?x?x?xf32>
   return
 }
 // CHECK-LABEL: func @generic_with_tensor_input
-// CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME:    indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME:    library_call = "some_external_function_name_1"}
-// CHECK-SAME:    {foo = 1 : i64}:
-// CHECK-SAME:    tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32>
+// CHECK-SAME:    {foo = 1 : i64}
+// CHECK:         tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32>
 
 // -----
 
-func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
 ]
@@ -354,31 +348,30 @@
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }
 
 func @generic_with_tensor_input_and_output(
     %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?xf32>)
     -> (tensor<?x?xf32>) {
-  %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %0 = linalg.generic #trait2 {foo = 1} %arg0, %arg1 {
+    ^bb(%0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @generic_with_tensor_input_and_output
-// CHECK:       linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK:       linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME:    indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME:    library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
-// CHECK-SAME:    tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK-SAME:    library_call = "some_external_function_name_1"}
+// CHECK-SAME:    {foo = 1 : i64}
+// CHECK-SAME:    %{{.*}}, %{{.*}}
+// CHECK:         tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
 // CHECK:         return {{.*}} : tensor<?x?xf32>
 
 // -----
 
-func @foo(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) -> f32 {
-  %f0 = constant 0.0 : f32
-  return %f0 : f32
-}
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i)>,
   affine_map<(i, j, k) -> (i, k, i + j)>
 ]
@@ -389,22 +382,26 @@
   args_out = 1,
   indexing_maps = #accesses,
   iterator_types = ["parallel", "parallel", "parallel"],
-  fun = @foo,
   library_call = "some_external_function_name_1"
 }
 
 func @indexed_generic_with_tensor_input_and_output(
     %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?xf32>)
     -> (tensor<?x?xf32>) {
-  %0 = linalg.indexed_generic #trait2 %arg0, %arg1 {foo = 1} :
-    tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %0 = linalg.indexed_generic #trait2 {foo = 1} %arg0, %arg1 {
+    ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
-// CHECK:       linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo,
+// CHECK:       linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
 // CHECK-SAME:    indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-// CHECK-SAME:    library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}:
-// CHECK-SAME:    tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK-SAME:    library_call = "some_external_function_name_1"}
+// CHECK-SAME:    {foo = 1 : i64}
+// CHECK-SAME:    %{{.*}}, %{{.*}}
+// CHECK:         tensor<?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
 // CHECK:         return {{.*}} : tensor<?x?xf32>
 
 // -----
 
@@ -460,10 +457,10 @@
 func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                      %arg1: memref<?x?x?xf32>) {
-  linalg.generic #trait3 %arg0, %arg1 {
+  linalg.generic #trait3 {foo = 1} %arg0, %arg1 {
     ^bb(%a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
-  } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+  } : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
       memref<?x?x?xf32>
   return
 }
@@ -471,17 +468,18 @@
 // CHECK:       linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME:    indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME:    library_call = "some_external_function_name_2"
+// CHECK-SAME:    {foo = 1 : i64}
 // CHECK:         ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 // CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
-// CHECK-SAME:    memref<?x?x?xf32>
+// CHECK:         memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
+// CHECK-SAME:    memref<?x?x?xf32>
 
 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                       %arg1: memref<?x?x?xf32>) {
-  linalg.indexed_generic #trait3 %arg0, %arg1 {
+  linalg.indexed_generic #trait3 {foo = 1} %arg0, %arg1 {
     ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
       linalg.yield %b : f32
-  } {foo = 1}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
+  }: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
      memref<?x?x?xf32>
   return
 }
@@ -489,9 +487,10 @@
 // CHECK:       linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64,
 // CHECK-SAME:    indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
 // CHECK-SAME:    library_call = "some_external_function_name_2"
+// CHECK-SAME:    {foo = 1 : i64}
 // CHECK:         ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 // CHECK:           linalg.yield %{{.*}} : f32
-// CHECK:         } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
+// CHECK:         }: memref<?x?xvector<3x4xi4>, #[[strided2D]]>,
 // CHECK-SAME:    memref<?x?x?xf32>
 
 // -----
 
@@ -621,3 +620,16 @@
 // CHECK-SAME:    memref<?x?x?xf32> into memref<?x?xf32>
 // CHECK:       linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
 // CHECK-SAME:    memref<?x?xf32> into memref<?x?x?xf32>
+
+
+// TODO(ntv): Return tensors need a semantics convention update.
+func @named_ops(%a3: memref, %b3: memref, %c3: memref, + %ta3: tensor, %tb3: tensor, %tc3: tensor) { + linalg.batchmatmul(%a3, %b3, %c3): (memref, memref, memref) -> () + linalg.batchmatmul(%ta3, %tb3, %c3): (tensor, tensor, memref) -> () + return +} +// CHECK-LABEL: func @named_ops +// CHECK: linalg.batchmatmul +// CHECK: linalg.batchmatmul + diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -212,57 +212,71 @@ // CHECK-LABEL: func @test_vectorize_fill // CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> -func @fma(%a: f32, %b: f32, %c: f32) -> f32 { - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 - return %e: f32 - } #matmul_accesses = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> ] #generic_matmul_trait = { - args_in = 2, - args_out = 1, - fun = @fma, - indexing_maps = #matmul_accesses, - library_call = "linalg_matmul", - iterator_types = ["parallel", "parallel", "reduction"] - } + args_in = 2, + args_out = 1, + indexing_maps = #matmul_accesses, + library_call = "linalg_matmul", + iterator_types = ["parallel", "parallel", "reduction"] +} func @permute_generic(%A: memref, %B: memref, %C: memref) { - linalg.generic #generic_matmul_trait %A, %B, %C : memref, memref, memref - + linalg.generic #generic_matmul_trait %A, %B, %C { + ^bb(%a: f32, %b: f32, %c: f32): + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + linalg.yield %e: f32 + }: memref, + memref, + memref return } // CHECK-LABEL : func @fma // CHECK-LABEL : func @permute_generic -// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref, memref, memref +// CHECK : linalg.generic {args_in = 2, args_out = 1, +// CHECK-SAME : indexing_maps = [#[[kn]], #[[nm]], #[[km]]], +// CHECK-SAME : iterator_types = ["parallel", "reduction", "parallel"], +// CHECK-SAME : library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} +// CHECK : memref, +// CHECK-SAME : memref, +// CHECK-SAME : memref -func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 { - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 - return %e: f32 -} #indexed_matmul_trait = { - args_in = 2, - args_out = 1, - fun = @fma_indexed, - indexing_maps = #matmul_accesses, - library_call = "linalg_matmul_indexed", - iterator_types = ["parallel", "parallel", "reduction"] + args_in = 2, + args_out = 1, + indexing_maps = #matmul_accesses, + library_call = "linalg_matmul_indexed", + iterator_types = ["parallel", "parallel", "reduction"] } -func @permute_generic_indexed(%A: memref, - %B: memref, - %C: memref) { - linalg.indexed_generic #indexed_matmul_trait %A, %B, %C : memref, memref, memref +func @permute_generic_indexed( + %A: memref, + %B: memref, + %C: memref) { + linalg.indexed_generic #indexed_matmul_trait %A, %B, %C { + ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + linalg.yield %e: f32 + } : memref, + memref, + memref return } // CHECK-LABEL : func @fma_indexed // CHECK-LABEL : func @permute_generic_indexed -// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = 
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -212,57 +212,71 @@
 // CHECK-LABEL: func @test_vectorize_fill
 // CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>

-func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
-  %d = mulf %a, %b: f32
-  %e = addf %c, %d: f32
-  return %e: f32
-}
 #matmul_accesses = [
-    affine_map<(m, n, k) -> (m, k)>,
-    affine_map<(m, n, k) -> (k, n)>,
-    affine_map<(m, n, k) -> (m, n)>
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
 ]
 #generic_matmul_trait = {
-    args_in = 2,
-    args_out = 1,
-    fun = @fma,
-    indexing_maps = #matmul_accesses,
-    library_call = "linalg_matmul",
-    iterator_types = ["parallel", "parallel", "reduction"]
-  }
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #matmul_accesses,
+  library_call = "linalg_matmul",
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
 func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                       %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                       %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.generic #generic_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
-
+  linalg.generic #generic_matmul_trait %A, %B, %C {
+    ^bb(%a: f32, %b: f32, %c: f32):
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e: f32
+  }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL : func @fma
 // CHECK-LABEL : func @permute_generic
-// CHECK : linalg.generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
+// CHECK : linalg.generic {args_in = 2, args_out = 1,
+// CHECK-SAME : indexing_maps = [#[[kn]], #[[nm]], #[[km]]],
+// CHECK-SAME : iterator_types = ["parallel", "reduction", "parallel"],
+// CHECK-SAME : library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}}
+// CHECK : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+// CHECK-SAME : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+// CHECK-SAME : memref<?x?xf32, offset: ?, strides: [?, 1]>

-func @fma_indexed(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) -> f32 {
-  %d = mulf %a, %b: f32
-  %e = addf %c, %d: f32
-  return %e: f32
-}
 #indexed_matmul_trait = {
-    args_in = 2,
-    args_out = 1,
-    fun = @fma_indexed,
-    indexing_maps = #matmul_accesses,
-    library_call = "linalg_matmul_indexed",
-    iterator_types = ["parallel", "parallel", "reduction"]
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #matmul_accesses,
+  library_call = "linalg_matmul_indexed",
+  iterator_types = ["parallel", "parallel", "reduction"]
 }
-func @permute_generic_indexed(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.indexed_generic #indexed_matmul_trait %A, %B, %C : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
+func @permute_generic_indexed(
+    %A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+    ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
+      %d = mulf %a, %b: f32
+      %e = addf %c, %d: f32
+      linalg.yield %e: f32
+  } : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+      memref<?x?xf32, offset: ?, strides: [?, 1]>,
+      memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL : func @fma_indexed
 // CHECK-LABEL : func @permute_generic_indexed
-// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1, fun = @fma, indexing_maps = [#[[kn]], #[[nm]], #[[km]]], iterator_types = ["parallel", "reduction", "parallel"], library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
+// CHECK : linalg.indexed_generic {args_in = 2, args_out = 1,
+// CHECK-SAME : indexing_maps = [#[[kn]], #[[nm]], #[[km]]],
+// CHECK-SAME : iterator_types = ["parallel", "reduction", "parallel"],
+// CHECK-SAME : library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}} :
+// CHECK : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+// CHECK-SAME : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+// CHECK-SAME : memref<?x?xf32, offset: ?, strides: [?, 1]>

 func @dot_perm(%x: memref<?xf32, offset: ?, strides: [1]>,
                %y: memref<?xf32, offset: ?, strides: [1]>,
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -9,6 +9,7 @@
 // RUN: mlir-edsc-builder-api-test | FileCheck %s -dump-input-on-failure

 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/LoopOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -111,7 +111,7 @@
               HasLinalgTransformMarker<"VECTORIZE">,
               PreconditionVectorizeLinalgOp
             ]>>)]>;
-def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
              [(VectorizeLinalgOp)],
              [(Constraint,
@@ -122,7 +122,7 @@
 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns.
 //===----------------------------------------------------------------------===//
-def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_),
          (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
          [(Constraint ]>>)]>;
-def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_),
          (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
          [(Constraint ]>>)]>;
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
-// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [
+// ODS-LABEL: def test1Op : LinalgNamedStructured_Op<"test1", [
 // ODS-NEXT: NInputs<2>,
 // ODS-NEXT: NOutputs<1>,
 // ODS-NEXT: NamedStructuredOpTraits]>
 //
-// IMPL-LABEL: matvec::referenceIterators() {
+// IMPL-LABEL: test1Op::referenceIterators() {
 // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-// IMPL: matvec::referenceIndexingMaps() {
+// IMPL: test1Op::referenceIndexingMaps() {
 // IMPL: AffineMap::get(2, 0, {d0, d1}),
 // IMPL-NEXT: AffineMap::get(2, 0, {d1}),
 // IMPL-NEXT: AffineMap::get(2, 0, {d0}) };
 //
-// IMPL: matvec::regionBuilder(ArrayRef<BlockArgument> args) {
+// IMPL: test1Op::regionBuilder(Block &block) {
 // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
 // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
 // IMPL: (linalg_yield(ValueRange{ [[e]] }));
 //
-def matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
+def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
   C(m) = std_addf(std_mulf(A(m, k), B(k)));
 }

-// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [
+// ODS-LABEL: def test2Op : LinalgNamedStructured_Op<"test2", [
 // ODS-NEXT: NInputs<2>,
 // ODS-NEXT: NOutputs<1>,
 // ODS-NEXT: NamedStructuredOpTraits]>
 //
-// IMPL-LABEL: matmul::referenceIterators() {
+// IMPL-LABEL: test2Op::referenceIterators() {
 // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-// IMPL: matmul::referenceIndexingMaps() {
+// IMPL: test2Op::referenceIndexingMaps() {
 // IMPL: AffineMap::get(3, 0, {d0, d2}),
 // IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}),
 // IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}) };
 //
-// IMPL: matmul::regionBuilder(ArrayRef<BlockArgument> args) {
+// IMPL: test2Op::regionBuilder(Block &block) {
 // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
 // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
 // IMPL: (linalg_yield(ValueRange{ [[e]] }));
 //
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
+def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
   C(m, n) = std_addf(std_mulf(A(m, k), B(k, n)));
 }

-// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
+// ODS-LABEL: def test3Op : LinalgNamedStructured_Op<"test3", [
 // ODS-NEXT: NInputs<2>,
 // ODS-NEXT: NOutputs<1>,
 // ODS-NEXT: NamedStructuredOpTraits]>
 //
-// IMPL-LABEL: batchmatmul::referenceIterators() {
+// IMPL-LABEL: test3Op::referenceIterators() {
 // IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-// IMPL: batchmatmul::referenceIndexingMaps() {
+// IMPL: test3Op::referenceIndexingMaps() {
 // IMPL: AffineMap::get(4, 0, {d0, d1, d3}),
 // IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}),
 // IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}) };
 //
-// IMPL: batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
+// IMPL: test3Op::regionBuilder(Block &block) {
 // IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 // IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
 // IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
 // IMPL: (linalg_yield(ValueRange{ [[e]] }));
 //
-// TBLGEN: batchmatmulOp
-def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
+// TBLGEN: test3Op
+def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n)));
 }
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -896,7 +896,8 @@
   TensorExpr(StringRef name,
              SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
              ArrayRef<unsigned> reductionDims)
-      : Expression(Kind::TensorExpr), opId(name), expressions(std::move(exprs)),
+      : Expression(Kind::TensorExpr), operationName(name),
+        expressions(std::move(exprs)),
         reductionDimensions(reductionDims.begin(), reductionDims.end()) {}

   static bool classof(const Expression *e) {
@@ -904,7 +905,7 @@
   }

   bool operator==(const TensorExpr &other) const {
-    if (opId != other.opId)
+    if (operationName != other.operationName)
       return false;
     if (expressions.size() != other.expressions.size())
       return false;
@@ -922,7 +923,7 @@
   template <typename Lambda>
   void visit(Lambda callback) const;

-  StringRef opId;
+  StringRef operationName;
   SmallVector<std::unique_ptr<Expression>, 4> expressions;
   SetVector<unsigned> reductionDimensions;
 };
@@ -995,15 +996,15 @@
                 StringRef linalgOpName);

   /// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
-  void printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
+  void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
                                ComprehensionParsingState &state);

   /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
-  void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
+  void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
                                   ComprehensionParsingState &state);

   /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
-  void printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
+  void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                           ComprehensionParsingState &state);

 private:
@@ -1364,6 +1365,7 @@
     return failure();
   StringRef tcName = parser.curToken.getSpelling();
+  std::string cppOpName = (tcName + "Op").str();
   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing tc: " << tcName << "\n");
   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
@@ -1404,7 +1406,7 @@
   SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
   while (parser.curToken.isNot(Token::Kind::r_brace)) {
     perComprehensionStates.push_back(ComprehensionParsingState());
-    if (failed(parseOneComprehension(tcName, tcName,
+    if (failed(parseOneComprehension(cppOpName, tcName,
                                      perComprehensionStates.back())))
       return failure();
   };
@@ -1418,16 +1420,16 @@
     return failure();
   }

   if (genODSDecl) {
-    printODS(os, tcName, tcName);
+    printODS(os, cppOpName, tcName);
     os << "\n";
   }
   if (genODSImpl) {
     auto &state = perComprehensionStates.back();
     std::string extraMethods;
     llvm::raw_string_ostream ss(extraMethods);
-    printReferenceIterators(ss, tcName, state);
-    printReferenceIndexingMaps(ss, tcName, state);
-    printRegionBuilder(ss, tcName, state);
+    printReferenceIterators(ss, cppOpName, state);
+    printReferenceIndexingMaps(ss, cppOpName, state);
+    printRegionBuilder(ss, cppOpName, state);
     ss.flush();
     os << extraMethods << "\n";
   }
@@ -1442,18 +1444,31 @@

 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
                         StringRef linalgOpName) {
-  const char *header = R"FMT(  def {0}Op : LinalgNamedStructured_Op<"{1}", [
+  const char *header = R"FMT(  def {0} : LinalgNamedStructured_Op<"{1}", [
     NInputs<{2}>,
     NOutputs<{3}>,
     NamedStructuredOpTraits]> {
       let arguments = (ins Variadic<LinalgOperand>:$views);
       let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
+      let regions = (region SizedRegion<1>:$region);
+      let builders = [OpBuilder<
+        "Builder *b, OperationState &result, TypeRange outputTypes, "
+        # "ValueRange views",
+        [{{
+          result.addOperands(views);
+          result.addTypes(outputTypes);
+          buildNamedStructuredOpRegion<{0}>(
+            *b, result, TypeRange(views), outputTypes);
+        }]>
+      ];
+      let parser = [{
+        return ::parseNamedStructuredOp<{0}>(parser, result);
+      }];
       let extraClassDeclaration = [{{
         llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
         llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
-        void regionBuilder(ArrayRef<BlockArgument> args);
+        static void regionBuilder(Block &block);
       }];
-      let hasFolder = 1;
  })FMT";

   unsigned nInputs = 0, nOutputs = 0;
@@ -1468,7 +1483,8 @@
 }

 /// Print the C++ StructuredOpsInterface impl of `referenceIterators`.
-void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId,
+void TCParser::printReferenceIterators(llvm::raw_ostream &os,
+                                       StringRef cppOpName,
                                        ComprehensionParsingState &state) {
   const char *referenceReferenceIteratorsFmt =
       R"FMT(
@@ -1498,11 +1514,12 @@
   });
   ss.flush();

-  os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr);
+  os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
 }

 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
-void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId,
+void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
+                                          StringRef cppOpName,
                                           ComprehensionParsingState &state) {
   const char *referenceIndexingMapsFmt =
       R"FMT(
@@ -1544,11 +1561,11 @@
   });
   mapsStringStream.flush();

-  os << llvm::formatv(referenceIndexingMapsFmt, opId, dimsStr, mapsStr);
+  os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
 }

 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
-void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId,
+void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
                                   ComprehensionParsingState &state) {
   unsigned count = state.orderedTensorArgs.size();
   llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
@@ -1570,15 +1587,17 @@
     });
     subExprsStringStream.flush();
     const char *tensorExprFmt = "\n  ValueHandle _{0} = {1}({2});";
-    os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs);
+    os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
+                        subExprs);
     subExprsMap[pTensorExpr] = count;
   }
 };

 const char *regionBuilderFmt = R"FMT(
-  void {0}::regionBuilder(ArrayRef<BlockArgument> args) {
+  void {0}::regionBuilder(Block &block) {
    using namespace edsc;
    using namespace intrinsics;
+    auto args = block.getArguments();
    ValueHandle {1};
    {2}
    (linalg_yield(ValueRange{ {3} }));
@@ -1612,8 +1631,8 @@
   expressionStringStream.flush();
   yieldStringStream.flush();

-  os << llvm::formatv(regionBuilderFmt, opId, valueHandleStr, expressionsStr,
-                      yieldStr);
+  os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
+                      expressionsStr, yieldStr);
 }

 /// Iterate over each Tensor Comprehension def.
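Note: for the multiply-accumulate specs in the tests above, the block that the
generated regionBuilder materializes (via the std_mulf/std_addf/linalg_yield
intrinsics emitted by regionBuilderFmt) corresponds to the following IR,
sketched for the f32 case:

    ^bb0(%a: f32, %b: f32, %c: f32):
      %0 = mulf %a, %b : f32
      %1 = addf %c, %0 : f32
      linalg.yield %1 : f32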