diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -15,6 +15,7 @@
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -87,4 +88,62 @@
   let hasVerifier = 1;
 }
 
+def Linalg_SoftmaxOp : Linalg_Op<"softmax",
+    [DestinationStyleOpInterface,
+     PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "Softmax operator";
+  let description = [{
+    linalg.softmax computes a numerically stable version of softmax.
+
+    For a given input tensor and a specified dimension `d`, compute:
+      1. the max `m` along that dimension `d`
+      2. f(x) = exp(x - m)
+      3. sum f(x) along dimension `d` to get l(x)
+      4. compute the final result f(x) / l(x)
+
+    This is an aggregate linalg operation that further reduces to a small DAG of
+    structured operations.
+  }];
+
+  let arguments = (ins AnyShaped:$input,
+                       AnyShaped:$output,
+                       I64Attr:$dimension
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+  let hasFolder = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `dimension` `(` $dimension `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    (`->` type($result)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    ShapedType getInputOperandType() {
+      return getInput().getType().cast<ShapedType>();
+    }
+    ShapedType getOutputOperandType() {
+      return getOutput().getType().cast<ShapedType>();
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+    // Method to implement DestinationStyleOpInterface.
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      std::pair<unsigned, unsigned> outputsIndexAndLength =
+          getODSOperandIndexAndLength(1);
+      return std::make_pair(
+          outputsIndexAndLength.first,
+          outputsIndexAndLength.first + outputsIndexAndLength.second);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2140,6 +2140,39 @@
 // All named ops canonicalizers and folders are auto-generated in the
 // .cpp.inc.
 
+//===----------------------------------------------------------------------===//
+// SoftmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SoftmaxOp::verify() {
+  ShapedType inputType = getInputOperandType();
+  ShapedType outputType = getOutputOperandType();
+
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(inputShape, outputShape)))
+    return emitOpError("incompatible output shape");
+
+  int64_t inputRank = getInputOperandRank();
+  int64_t dimension = getDimension();
+  if ((dimension < 0) || (dimension >= inputRank))
+    return emitOpError("incorrect dimension specified");
+
+  return success();
+}
+
+// cast(dynamic) -> static.
+LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+LogicalResult
+SoftmaxOp::reifyResultShapes(OpBuilder &b,
+                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return cast<ReifyRankedShapedTypeOpInterface>(getOperation())
+      .reifyResultShapes(b, reifiedReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -733,3 +733,14 @@
   linalg.generic {} ins() outs()
   return
 }
+
+// -----
+
+func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x16xf32> {
+  %0 = tensor.empty() : tensor<2x16xf32>
+  // expected-error @+1 {{incompatible output shape}}
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>)
+         outs(%0: tensor<2x16xf32>)
+         -> tensor<2x16xf32>
+  return %1 : tensor<2x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -599,3 +599,17 @@
 // CHECK-SAME:   outs
 // CHECK-SAME:   dimensions = [1]
 // CHECK-NEXT:   return %[[REDUCED]] : tensor<16x64xf32>
+
+// -----
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %0 = tensor.empty() : tensor<2x16x32xf32>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+// CHECK:      func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK:        %[[D1:.+]] = linalg.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK:        return %[[D1]] : tensor<2x16x32xf32>
+// CHECK:      }
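
For illustration only: because `$input`/`$output` are declared `AnyShaped` and `$result` is variadic, the op definition above also admits a buffer (memref) form that produces no results, which is the only form the `memref::foldMemRefCast` folder can act on. A minimal sketch of that form follows; the function and value names are illustrative and not taken from the tests above:

  func.func @softmax_memref(%in: memref<2x16x32xf32>, %out: memref<2x16x32xf32>) {
    // Buffer semantics: the op writes into %out in place, so it yields no result
    // and the optional trailing `->` type group of the assembly format is omitted.
    linalg.softmax dimension(2) ins(%in : memref<2x16x32xf32>)
                                outs(%out : memref<2x16x32xf32>)
    return
  }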