diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -19,6 +19,7 @@
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TilingInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
 // Base class for Linalg dialect ops that do not correspond to library calls.
@@ -91,7 +92,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
     [DestinationStyleOpInterface,
      PredOpTrait<"input and output have same element type",
                  TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TilingInterface,
+         ["getIterationDomain",
+          "getLoopIteratorTypes",
+          "getResultTilePosition",
+          "getTiledImplementation"]>]> {
   let summary = "Softmax operator";
   let description = [{
     linalg.softmax computes a numerically stable version of softmax.
@@ -104,6 +110,12 @@
 
     This is an aggregate linalg operation that further reduces to a small DAG
     of structured operations.
+
+    Warning: Regarding the tiling capabilities, the implementation doesn't
+    check that the provided dimensions make sense. This is the responsibility
+    of the transformation calling the tiling to ensure that the provided
+    sizes for each dimension make sense with respect to the semantics of
+    softmax.
   }];
 
   let arguments = (ins AnyShaped:$input,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -46,6 +47,41 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
+static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
+                                int64_t dim) {
+  auto type = cast<ShapedType>(v.getType());
+  if (!type.isDynamicDim(dim))
+    return builder.getIndexAttr(type.getDimSize(dim));
+
+  return getAsOpFoldResult(
+      TypeSwitch<Type, Value>(v.getType())
+          .Case([&](RankedTensorType t) -> Value {
+            return builder.create<tensor::DimOp>(loc, v, dim);
+          })
+          .Case([&](MemRefType t) -> Value {
+            return builder.create<memref::DimOp>(loc, v, dim);
+          }));
+}
+
+/// Returns a memref.subview or a tensor.extract_slice based on the type of the
+/// `source`.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Value>(source.getType())
+      .Case([&](RankedTensorType t) -> Value {
+        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                strides);
+      })
+      .Case([&](MemRefType type) -> Value {
+        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+                                           strides);
+      })
+      .Default([&](Type t) { return nullptr; });
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
@@ -2182,6 +2218,62 @@
   return success();
 }
 
+SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
+  int64_t operandRank = getInputOperandRank();
+  SmallVector<Range> loopBounds(operandRank);
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value source = getInput();
+  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
+                                                 utils::IteratorType::parallel);
+  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
+  return iteratorTypes;
+}
+
+FailureOr<TilingResult>
+SoftmaxOp::getTiledImplementation(OpBuilder &builder,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes) {
+  int64_t rank = getInputOperandRank();
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  SmallVector<OpFoldResult> strides(rank, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(
+      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
+  tiledOperands.emplace_back(
+      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
+
+  SmallVector<Type> resultTypes;
+  if (hasTensorSemantics())
+    resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+LogicalResult SoftmaxOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  if (resultNumber == 0) {
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+  return failure();
+}
+
 // cast(dynamic) -> static.
 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -555,7 +555,13 @@
   auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
       rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
-  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  if (failed(maybeRankReduced)) {
+    diag.attachNote(producerOp->getLoc())
+        << "shape types don't match (missing canonicalization?):\nTiledOp: "
+        << tileAndFuseResult->tiledValues[0]
+        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
+    return {};
+  }
   rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
 
   // Add new outputs to containing op, if required
diff --git a/mlir/test/Dialect/Linalg/tile-softmax.mlir b/mlir/test/Dialect/Linalg/tile-softmax.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-softmax.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize --split-input-file | FileCheck %s
+
+// Check that we can tile softmax on tensors.
+// The tiling here is 2x3.
+// So the shape used in the inner loop should be 2x3x256; however, since 3
+// doesn't divide the second dimension (64), we should see a '?' in the shape.
+// The actual size, used through extract_slice/insert_slice, should come from
+// `min(64 - current iteration index, 3)`.
+
+// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)>
+// CHECK-LABEL: func.func @softmax(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[TENSOR_EMPTY:.*]] = tensor.empty() : tensor<16x64x256xf32>
+// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C16]] step %[[C2]] iter_args(%[[VAL_9:.*]] = %[[TENSOR_EMPTY]]) -> (tensor<16x64x256xf32>) {
+// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C64]] step %[[C3]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]]) -> (tensor<16x64x256xf32>) {
+// CHECK: %[[VAL_13:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_11]])
+// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
+// CHECK: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
+// CHECK: %[[VAL_16:.*]] = linalg.softmax dimension(1) ins(%[[VAL_14]] : tensor<2x?x256xf32>) outs(%[[VAL_15]] : tensor<2x?x256xf32>) -> tensor<2x?x256xf32>
+// CHECK: %[[VAL_17:.*]] = tensor.insert_slice %[[VAL_16]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<2x?x256xf32> into tensor<16x64x256xf32>
+// CHECK: scf.yield %[[VAL_17]] : tensor<16x64x256xf32>
+// CHECK: }
+// CHECK: scf.yield %[[VAL_18:.*]] : tensor<16x64x256xf32>
+// CHECK: }
+// CHECK: return %[[VAL_19:.*]] : tensor<16x64x256xf32>
+// CHECK: }
+func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+  %0 = tensor.empty() : tensor<16x64x256xf32>
+  %1 = linalg.softmax
+         dimension(1) ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
+  return %1 : tensor<16x64x256xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1, %loop:2 = transform.structured.tile %0 [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// Test the softmax tiling interface with the tile_to_forall_op transform and
+// check that it composes properly with the fuse transform.
+// This should sink the linalg.generic inside the scf.forall and run that
+// generic on 2x4x256 tensors (2==16/8, 4==64/16).
+
+// CHECK: #[[$TIMES2_MAP:.*]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: #[[$TIMES4_MAP:.*]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK-LABEL: func.func @softmax_tile_n_fuse(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<16x64x256xf32>
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x64x256xf32>
+// CHECK: %[[VAL_4:.*]] = scf.forall (%[[VAL_5:.*]], %[[VAL_6:.*]]) in (8, 16) shared_outs(%[[VAL_7:.*]] = %[[VAL_3]]) -> (tensor<16x64x256xf32>) {
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
+// CHECK: %[[VAL_9:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
+// CHECK: %[[VAL_10:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
+// CHECK: %[[VAL_11:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
+// CHECK: %[[VAL_12:.*]] = affine.apply #[[$TIMES2_MAP]](%[[VAL_5]])
+// CHECK: %[[VAL_13:.*]] = affine.apply #[[$TIMES4_MAP]](%[[VAL_6]])
+// CHECK: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
+// CHECK: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_2]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
+// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_14]] : tensor<2x4x256xf32>) outs(%[[VAL_15]] : tensor<2x4x256xf32>) {
+// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32):
+// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_18]], %[[VAL_1]] : f32
+// CHECK: linalg.yield %[[VAL_19]] : f32
+// CHECK: } -> tensor<2x4x256xf32>
+// CHECK: %[[VAL_20:.*]] = tensor.extract_slice %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x4x256xf32>
+// CHECK: %[[VAL_21:.*]] = linalg.softmax dimension(1) ins(%[[VAL_22:.*]] : tensor<2x4x256xf32>) outs(%[[VAL_20]] : tensor<2x4x256xf32>) -> tensor<2x4x256xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[VAL_21]] into %[[VAL_7]]{{\[}}%[[VAL_8]], %[[VAL_9]], 0] [2, 4, 256] [1, 1, 1] : tensor<2x4x256xf32> into tensor<16x64x256xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_23:.*]] : tensor<16x64x256xf32>
+// CHECK: }
+
+func.func @softmax_tile_n_fuse(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+  %empty = tensor.empty() : tensor<16x64x256xf32>
+  %cst = arith.constant 1.000000e+00 : f32
+  %eltwise = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"] + } + ins(%arg0 : tensor<16x64x256xf32>) + outs(%empty : tensor<16x64x256xf32>) { + ^bb0(%arg2: f32, %arg3: f32): + %arg3Plus1 = arith.addf %arg3, %cst : f32 + linalg.yield %arg3Plus1 : f32 + } -> tensor<16x64x256xf32> + + %0 = tensor.empty() : tensor<16x64x256xf32> + %1 = linalg.softmax + dimension(1) ins(%eltwise : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32> + return %1 : tensor<16x64x256xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + // Tile the root. + %forall_op, %tiled_op = transform.structured.tile_to_forall_op %0 num_threads [8, 16] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Fuse all producers. + %1 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.fuse_into_containing_op %1 into %forall_op + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) +} +// ----- + +// Same as the previous test but on memrefs. + +// CHECK: #[[$MIN_MAP:.*]] = affine_map<(d0) -> (-d0 + 64, 3)> +// CHECK-LABEL: func.func @softmax_memref( +// CHECK-SAME: %[[VAL_0:.*]]: memref<16x64x256xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: memref<16x64x256xf32>) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK: scf.for %[[VAL_7:.*]] = %[[C0]] to %[[C16]] step %[[C2]] { +// CHECK: scf.for %[[VAL_8:.*]] = %[[C0]] to %[[C64]] step %[[C3]] { +// CHECK: %[[VAL_9:.*]] = affine.min #[[$MIN_MAP]](%[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>> +// CHECK: %[[VAL_11:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_7]], %[[VAL_8]], 0] [2, %[[VAL_9]], 256] [1, 1, 1] : memref<16x64x256xf32> to memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>> +// CHECK: linalg.softmax dimension(1) ins(%[[VAL_10]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>) outs(%[[VAL_11]] : memref<2x?x256xf32, strided<[16384, 256, 1], offset: ?>>) +// CHECK: } +// CHECK: } +// CHECK: return +// CHECK: } +func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) { + linalg.softmax + dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>) + return +} + +transform.sequence failures(propagate) { + ^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop:2 = transform.structured.tile %0 [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) +}