[mlir][linalg] Add linalg.map; allow memref operands in linalg.reduce.

Adds `linalg.map`, a destination-style structured op whose single-block
`mapper` region computes each element of its single `outs` operand from the
corresponding elements of its `ins` operands (tensors or memrefs). The
verifier checks that the mapper arity equals the number of inputs, that each
mapper block-argument type equals the matching input element type, and that
every input shape equals the output shape.

`parseDstStyleOp` is moved above MapOp and now adds a result type only for
ranked-tensor `outs` operands, so memref-based `linalg.map`/`linalg.reduce`
parse with zero results. The ReduceOp result constraint is changed
accordingly, and `getAsmResultNames` of both ops tolerates an empty result
list.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -225,12 +225,77 @@
 
 
 //===----------------------------------------------------------------------===//
-// Reduce op.
+// Map op.
 //===----------------------------------------------------------------------===//
 
 def TensorOrMemref :
   AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
 
+def MapOp : LinalgStructuredBase_Op<"map", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Elementwise operations";
+  let description = [{
+    Models elementwise operations on tensors or memrefs in terms of arithmetic
+    operations on the corresponding elements.
+
+    Example:
+    ```
+      %add = linalg.map
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init: tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            linalg.yield %0: f32
+          }
+    ```
+  }];
+
+  let arguments = (ins
+    // Input args
+    Variadic<TensorOrMemref>:$inputs,
+
+    // Output arg
+    TensorOrMemref:$init
+  );
+  let results = (outs Variadic<AnyTensor>:$result);
+  let regions = (region SizedRegion<1>:$mapper);
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Implement functions necessary for LinalgStructuredInterface.
+    ArrayAttr getIteratorTypes();
+    ArrayAttr getIndexingMaps();
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    unsigned getNumInputs() {
+      return this->getOperation()->getNumOperands() - getNumOutputs();
+    }
+    unsigned getNumOutputs() { return 1; }
+    mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
+    linalg::OpOperandVector getOpOperandsMatchingBBargs() {
+      return getInputOperands();
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return nullptr;
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Reduce op.
+//===----------------------------------------------------------------------===//
+
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
@@ -264,7 +329,7 @@
     ConfinedAttr<DenseI64ArrayAttr,
                  [DenseArrayStrictlySorted<DenseI64ArrayAttr>]>:$dimensions
   );
-  let results = (outs Variadic<TensorOrMemref>);
+  let results = (outs Variadic<AnyTensor>);
   let regions = (region SizedRegion<1>:$combiner);
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1288,6 +1288,135 @@
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// MapOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDstStyleOp(
+    OpAsmParser &parser, OperationState &result,
+    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
+        nullptr) {
+  // Parse `ins` and `outs`.
+  SmallVector<Type, 4> inputTypes, outputTypes;
+  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
+                                   /*addOperandSegmentSizes=*/false))
+    return failure();
+
+  // Add result types.
+  for (Type outputType : outputTypes) {
+    if (outputType.isa<RankedTensorType>())
+      result.addTypes(outputType);
+  }
+
+  // Parse required attributes.
+  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
+    return failure();
+
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+void MapOp::getAsmBlockArgumentNames(Region &region,
+                                     OpAsmSetValueNameFn setNameFn) {
+  for (Value v : getRegionInputArgs())
+    setNameFn(v, "in");
+}
+
+void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "mapped");
+}
+
+ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseDstStyleOp(parser, result))
+    return failure();
+
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                               /*allowType=*/true, /*allowAttrs=*/true)) {
+    return failure();
+  }
+
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  return success();
+}
+
+void MapOp::print(OpAsmPrinter &p) {
+  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  p.printOptionalAttrDict((*this)->getAttrs());
+
+  p << "(";
+  llvm::interleaveComma(getMapper().getArguments(), p,
+                        [&](auto arg) { p.printRegionArgument(arg); });
+  p << ") ";
+
+  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+}
+
+LogicalResult MapOp::verify() {
+  auto *bodyBlock = getBody();
+  auto blockArgs = bodyBlock->getArguments();
+
+  // Checks if the number of `inputs` match the arity of the `mapper` region.
+  if (getInputs().size() != blockArgs.size())
+    return emitOpError() << "expects number of operands to match the arity of "
+                            "mapper, but got: "
+                         << getInputs().size() << " and " << blockArgs.size();
+
+  // The parameters of mapper should all match the element type of inputs.
+  for (const auto &[bbArgType, inputArg] :
+       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
+    auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
+    if (bbArgType != inputElemType) {
+      return emitOpError() << "expected element type of input " << inputElemType
+                           << " to match bbArg type " << bbArgType;
+    }
+  }
+
+  // The shape of each input must match the shape of the output.
+  auto outputShape =
+      getOutputs().front().getType().cast<ShapedType>().getShape();
+  for (Type inputArgType : TypeRange{getInputs()}) {
+    auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
+    if (inputElemShape != outputShape) {
+      return emitOpError() << "expected shape of input (" << inputElemShape
+                           << ") to match shape of output (" << outputShape
+                           << ")";
+    }
+  }
+
+  return success();
+}
+
+ArrayAttr MapOp::getIteratorTypes() {
+  int64_t rank = getInit().getType().getRank();
+  return Builder(getContext())
+      .getStrArrayAttr(
+          SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+}
+
+ArrayAttr MapOp::getIndexingMaps() {
+  Builder builder(getContext());
+  int64_t rank = getInit().getType().getRank();
+  int64_t numIndexingMaps = getOperands().size();
+  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
+      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
+}
+
+void MapOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  SmallVector<Value> inputBuffers = getInputBufferOperands();
+  SmallVector<Value> outputBuffers = getOutputBufferOperands();
+  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+                        outputBuffers);
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
@@ -1302,7 +1431,8 @@
 
 void ReduceOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResults().front(), "reduced");
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "reduced");
 }
 
 ArrayAttr ReduceOp::getIteratorTypes() {
@@ -1336,33 +1466,6 @@
                         outputBuffers);
 }
 
-static ParseResult parseDstStyleOp(
-    OpAsmParser &parser, OperationState &result,
-    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
-        nullptr) {
-  // Parse `ins` and `outs`.
-  SmallVector<Type, 4> inputTypes, outputTypes;
-  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
-                                   /*addOperandSegmentSizes=*/false))
-    return failure();
-
-  // Add result types.
-  for (Type outputType : outputTypes) {
-    if (!outputType.isa<RankedTensorType>())
-      return failure();
-    result.addTypes(outputType);
-  }
-
-  // Parse required attributes.
-  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
-    return failure();
-
-  // Parse optional attributes.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  return success();
-}
-
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -391,6 +391,70 @@
 
 // -----
 
+func.func @map_binary_wrong_yield_operands(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        // expected-error @+1{{'linalg.yield' op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
+        linalg.yield %0, %0: f32, f32
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_arity_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_type_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error@+1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f64, %rhs_elem: f64) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f64
+        linalg.yield %0: f64
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_output_shape_mismatch(
+    %lhs: tensor<64x64xf32>, %rhs: tensor<64x64xf32>, %init: tensor<32xf32>)
+    -> tensor<32xf32> {
+  // expected-error@+1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
+      outs(%init:tensor<32xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<32xf32>
+}
+
+// -----
+
 func.func @reduce_input_vs_init_dimension_mismatch(
     %input: tensor<16x32x64xf32>,
     %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -354,8 +354,70 @@
 
 // -----
 
+func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+                      %init: tensor<64xf32>) -> tensor<64xf32> {
+  %add = linalg.map
+      ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<64xf32>
+}
+// CHECK-LABEL: func @map_binary
+//       CHECK:   linalg.map
+
+// -----
+
+func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
+                             %init: memref<64xf32>) {
+  linalg.map
+      ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
+      outs(%init:memref<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return
+}
+// CHECK-LABEL: func @map_binary_memref
+//       CHECK:   linalg.map
+
+// -----
+
+func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> {
+  %abs = linalg.map
+      ins(%input:tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%input_elem: f32) {
+        %0 = math.absf %input_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %abs : tensor<64xf32>
+}
+// CHECK-LABEL: func @map_unary
+//       CHECK:   linalg.map
+
+// -----
+
+func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
+  linalg.map
+      ins(%input:memref<64xf32>)
+      outs(%init:memref<64xf32>)
+      (%input_elem: f32) {
+        %0 = math.absf %input_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return
+}
+// CHECK-LABEL: func @map_unary_memref
+//       CHECK:   linalg.map
+
+// -----
+
 func.func @reduce(%input: tensor<16x32x64xf32>,
-                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)
@@ -371,6 +433,23 @@
 
 // -----
 
+func.func @reduce_memref(%input: memref<16x32x64xf32>,
+                         %init: memref<16x64xf32>) {
+  linalg.reduce
+      ins(%input:memref<16x32x64xf32>)
+      outs(%init:memref<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return
+}
+// CHECK-LABEL: func @reduce_memref
+//       CHECK:   linalg.reduce
+
+// -----
+
 func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
     %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>,
     %init2: tensor<16x64xi64>) -> (tensor<16x64xf32>, tensor<16x64xi64>) {
@@ -387,3 +466,22 @@
 }
 // CHECK-LABEL: func @variadic_reduce
 //       CHECK:   linalg.reduce
+
+// -----
+
+func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
+    %init1: memref<16x64xf32>, %input2: memref<16x32x64xi64>,
+    %init2: memref<16x64xi64>) {
+  linalg.reduce
+      ins(%input1, %input2 : memref<16x32x64xf32>, memref<16x32x64xi64>)
+      outs(%init1, %init2 : memref<16x64xf32>, memref<16x64xi64>)
+      dimensions = [1]
+      (%in1: f32, %in2: i64, %out1: f32, %out2: i64) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addi %in2, %out2: i64
+        linalg.yield %0, %1: f32, i64
+      }
+  func.return
+}
+// CHECK-LABEL: func @variadic_reduce_memref
+//       CHECK:   linalg.reduce