diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -897,4 +897,34 @@
   let verifyWithRegions = 1;
 }
 
+def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
+  let description = [{
+    Interface for decomposing aggregated operations into a sequence of simpler
+    ops.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Method to decompose the operation into simpler operations.
+
+        On success, this method returns one `Value` per result in the
+        original operation.
+        The order of the returned values must match the order of the
+        original values.
+        In other words, the returned vector can be used directly with
+        `RewriterBase::replaceOp(this, returnedValues)`.
+      }],
+      /*retType=*/"FailureOr<SmallVector<Value>>",
+      /*methodName=*/"decomposeOperation",
+      /*args=*/(ins
+          "OpBuilder &":$b),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return {};
+      }]
+    >
+  ];
+}
+
 #endif // LINALG_IR_LINALGINTERFACES
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -14,6 +14,7 @@
 #define LINALG_OPS
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -93,6 +94,7 @@
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+     DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ ... @@
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    TODO
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2323,6 +2323,176 @@
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+// Helper functions for softmax decomposition.
+// @{
+
+// Helper function to produce the iterator types (reduction or parallel) and
+// affine maps for the iterators used in the decomposition of softmax.
+// This method creates:
+// If allParallel == true:
+//  - iterator type: {parallel, ..., parallel}
+//  - affine maps:
+//    -- identity with inputRank dimensions.
+//    -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//       where N == inputRank.
+//
+// If allParallel == false:
+//  - iterator type at dim(i) == parallel for i != \p dim and
+//    dim(dim) == reduction.
+//  - affine map:
+//    -- identity with inputRank dimensions.
+//    -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
+//       where N == inputRank.
+static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
+computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
+                                    int64_t dim, bool allParallel = false) {
+  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
+                                                 utils::IteratorType::parallel);
+  if (!allParallel)
+    iteratorTypes[dim] = utils::IteratorType::reduction;
+  MLIRContext *ctxt = builder.getContext();
+  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
+  SmallVector<AffineExpr> affineExprs;
+  for (int i = 0; i < inputRank; i++) {
+    if (i != dim)
+      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
+  }
+  auto reductionMap =
+      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
+  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
+  return std::make_tuple(iteratorTypes, indexingMaps);
+}
+
+// Helper function to produce a linalg.generic that computes a reduction on
+// dimension \p dim with the operation type \p T.
+template <typename T>
+static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
+                    int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] =
+      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
+  assert(indexingMaps.size() == 2 &&
+         "We should have two maps: 1 for the input, 1 for the output");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<T>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the second step of the softmax
+/// decomposition: res = exp(input - max), where \p max is the max of \p input
+/// on dimension \p dim.
+static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
+                              Value max, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 && "We should have one map for each input");
+  assert(indexingMaps[0].isIdentity() && "input map should be identity");
+  // Add the affine map for the output argument.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
+      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
+        Value result = b.create<math::ExpOp>(loc, diff);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+
+/// Produce a linalg generic that computes the final step of the softmax
+/// decomposition.
+/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
+///   yield n / d
+/// }
+static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
+                        Value denominator, Value output, int64_t dim) {
+  auto inputType = cast<ShapedType>(numerator.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
+      builder, inputRank, dim, /*allParallel=*/true);
+  assert(indexingMaps.size() == 2 &&
+         "We should have one map for each input (2)");
+  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
+  // Add the affine map for the output tensor.
+  indexingMaps.push_back(indexingMaps[0]);
+  auto genericOp = builder.create<linalg::GenericOp>(
+      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
+        b.create<linalg::YieldOp>(loc, result);
+      });
+  return genericOp.getResult(0);
+}
+// @} End helper functions for softmax decomposition.
+
+/// Given an N-dimensional tensor x, this method converts
+/// softmax(x) to the following sequence of operations:
+///
+/// 1. Compute the max of x along dimension d. This results
+///    in an N-1 dimensional tensor m.
+///    m = max(x, dim = d)
+///
+/// 2. Subtract a broadcasted m from x and exponentiate. This results in
+///    an N-dimensional tensor z.
+///    z = exp(x - m)
+///
+/// 3. Compute the sum of z along dimension d. This results in
+///    an N-1 dimensional tensor l.
+///    l = sum(z, dim = d)
+///
+/// 4. Divide z and l. This gives the N-dimensional softmax.
+///    softmax = z / l
+///
+FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(*this);
+  Location loc = getLoc();
+  Value input = getInput();
+  ShapedType inputType = getInputOperandType();
+  Type elementType = inputType.getElementType();
+  int64_t reductionDim = getDimension();
+  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
+  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  dims.erase(dims.begin() + reductionDim);
+  // Step 1: Compute max along dim.
+  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value neutralForMaxF =
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
+  Value neutralForMaxFInit =
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+  Value max =
+      reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+
+  // Step 2: Subtract max from input and exponentiate.
+  Value numerator =
+      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+
+  // Step 3: Compute sum along dim.
+  Value zero =
+      arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
+  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value denominator =
+      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+
+  // Step 4: Compute softmax.
+  Value result =
+      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+  return SmallVector<Value>{result};
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -335,6 +335,38 @@
   return emitDefaultSilenceableFailure(target);
 }
 
+//===----------------------------------------------------------------------===//
+// DecomposeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+// Decompose the target operation if it implements the AggregatedOpInterface.
+// Push the decomposed operations (the ones that replace the values produced by
+// \p target) into `results`.
+DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+  if (!decomposableOp) {
+    failed(rewriter.notifyMatchFailure(target,
+                                       "payload is not a decomposable op"));
+    return emitDefaultSilenceableFailure(target);
+  }
+
+  FailureOr<SmallVector<Value>> maybeNewResults =
+      decomposableOp.decomposeOperation(rewriter);
+  if (failed(maybeNewResults))
+    return emitDefaultSilenceableFailure(target);
+
+  rewriter.replaceOp(decomposableOp, *maybeNewResults);
+  for (Value val : *maybeNewResults) {
+    Operation *definition = val.getDefiningOp();
+    if (definition)
+      results.push_back(definition);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // EliminateLinalgOpAnchoredEmptyTensorsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor,
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -199,8 +202,54 @@
   return %0 : tensor
 }
 
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @softmax(
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:  "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:   %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32
+// CHECK:   linalg.yield %[[D8]] : f32
+// CHECK: } -> tensor<2x16xf32>
+// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:  ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:  outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:   %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
+// CHECK:   %[[D9:.+]] = math.exp %[[D8]] : f32
+// CHECK:   linalg.yield %[[D9]] : f32
+// CHECK: } -> tensor<2x16x32xf32>
+// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK: %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
+// CHECK-SAME:  "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:   %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
+// CHECK:   linalg.yield %[[D8]] : f32
+// CHECK: } -> tensor<2x16xf32>
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:  ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK-SAME:  outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:   %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
+// CHECK:   linalg.yield %[[D8]] : f32
+// CHECK: } -> tensor<2x16x32xf32>
+// CHECK: return %[[D7]] : tensor<2x16x32xf32>
+// CHECK: }
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
+
+  %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
 }
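
Aside from the transform op exercised above, the new AggregatedOpInterface can also be driven directly from C++. The sketch below is illustrative only: the DecomposeAggregateOps pattern name and the surrounding boilerplate are hypothetical and not part of this patch; it simply mirrors what DecomposeInterfaceOp::applyToOne does, i.e. query the interface, call decomposeOperation, and replace the original op with the returned values.

// Illustrative sketch only: DecomposeAggregateOps is a hypothetical pattern
// name, not part of this patch.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct DecomposeAggregateOps : public mlir::RewritePattern {
  DecomposeAggregateOps(mlir::MLIRContext *ctx)
      : mlir::RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only ops implementing AggregatedOpInterface know how to decompose.
    auto aggregate = llvm::dyn_cast<mlir::AggregatedOpInterface>(op);
    if (!aggregate)
      return mlir::failure();
    // One replacement value per original result, in the original order.
    mlir::FailureOr<llvm::SmallVector<mlir::Value>> replacements =
        aggregate.decomposeOperation(rewriter);
    if (mlir::failed(replacements))
      return mlir::failure();
    rewriter.replaceOp(op, *replacements);
    return mlir::success();
  }
};
} // namespace

A pattern like this could be collected into a RewritePatternSet and run with the greedy driver; the transform op above instead applies the decomposition only to the payload ops selected by the handle and reports the defining ops of the replacement values back through the result handle.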