diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -32,6 +32,35 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
+  let summary = "operation to define a tensor of a particular value";
+
+  let description = [{
+    `linalg.init_tensor` is an operation that materializes a tensor of
+    the given scalar value.
+
+    For now it is only used in the context of tiling, to materialize a
+    tensor that the tile loop body updates the tiled results into
+    (using destructive updates). In reality, it would be better to
+    leave this tensor uninitialized, but that opens up issues of how
+    to handle "undef" values correctly. To side-step that for now, an
+    initialization value is specified, with the assumption that the
+    entire tensor is overwritten during tiling. Proceed with caution
+    when using this operation.
+  }];
+
+  let arguments =
+      (ins AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
+
+  let results = (outs AnyTensor:$result);
+
+  let assemblyFormat = [{
+    $value attr-dict `:` type($value) `into` type($result)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 def Linalg_RangeOp :
     Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -65,6 +65,15 @@
     for (unsigned i = 0, e = t.getRank(); i < e; ++i)
       res.push_back(b.create<DimOp>(loc, v, i));
   }
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(getOperation());
+  if (getNumInitTensors() == 0) {
+    for (Value v : getOperation()->getResults()) {
+      ShapedType t = v.getType().template cast<ShapedType>();
+      for (unsigned i = 0, e = t.getRank(); i < e; ++i)
+        res.push_back(b.create<DimOp>(loc, v, i));
+    }
+  }
   return res;
 }
 
@@ -147,6 +156,18 @@
 // LinalgOps.td), we define an overloaded `print` function and a
 // parse`className` function.
 
+//===----------------------------------------------------------------------===//
+// InitTensorOp
+//===----------------------------------------------------------------------===//
+static LogicalResult verify(InitTensorOp op) {
+  if (op.value().getType() !=
+      op.result().getType().cast<ShapedType>().getElementType()) {
+    return op.emitError(
+        "mismatch in value type and element type of the result tensor");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GenericOps
 //===----------------------------------------------------------------------===//
@@ -1707,6 +1728,11 @@
       newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
       newResultTypes.push_back(newOperands.back().getType());
     }
+    if (linalgOp.getNumInitTensors() == 0) {
+      for (Value v : linalgOp.getOperation()->getResults()) {
+        newResultTypes.push_back(v.getType());
+      }
+    }
     auto extraOperands = linalgOp.getAssumedNonShapedOperands();
     newOperands.append(extraOperands.begin(), extraOperands.end());
     // Clone op.
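Note: as a quick illustration of the op introduced above, a hand-written sketch based on the declared assembly format and on the `fill_tensor` roundtrip test added later in this patch (the `tensor<42x21xf32>` form appears verbatim in that test; the constant and value names are illustrative):

    %zero = constant 0.0 : f32
    // Materializes a 42x21 tensor with every element set to %zero.
    %0 = linalg.init_tensor %zero : f32 into tensor<42x21xf32>

This is also how the tiling change below seeds results that have no init tensors: it builds a zero constant of the result element type (`b.getZeroAttr`) and wraps it in a `linalg.init_tensor`, which the generated loop nest then updates through `subtensor_insert`.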
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
@@ -220,9 +221,9 @@
 
 static SmallVector<Value, 4>
 makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
-                ValueRange operands, AffineMap map, ValueRange ivs,
-                ValueRange tileSizes, ValueRange allShapeSizes) {
-  assert(operands.size() == linalgOp.getShapedOperands().size());
+                ValueRange operands, ValueRange nonInitResultTensors,
+                AffineMap map, ValueRange ivs, ValueRange tileSizes,
+                ValueRange allShapeSizes) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
@@ -242,11 +243,12 @@
     subShapeSizes.push_back(size - std_constant_index(1));
   }
 
-  auto *op = linalgOp.getOperation();
-
   SmallVector<Value, 4> res;
-  res.reserve(op->getNumOperands());
-  for (auto en : llvm::enumerate(operands)) {
+  res.reserve(operands.size() + nonInitResultTensors.size());
+  SmallVector<Value, 4> tiledOperands(operands.begin(), operands.end());
+  tiledOperands.append(nonInitResultTensors.begin(),
+                       nonInitResultTensors.end());
+  for (auto en : llvm::enumerate(tiledOperands)) {
     Value shapedOp = en.value();
     ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
     unsigned rank = shapedType.getRank();
@@ -341,6 +343,7 @@
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
+
   SmallVector<Attribute, 4> iteratorTypes;
   for (auto attr :
        enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
@@ -374,9 +377,20 @@
   // 2. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs, tensorResults;
-  auto initTensors = op.getInitTensors();
+  SmallVector<Value, 1> loopArgs = llvm::to_vector<1>(op.getInitTensors());
+  SmallVector<Value, 1> nonInitResultTensors;
+  if (op.getOperation()->getNumResults() != 0 && op.getNumInitTensors() == 0) {
+    for (ShapedType resultType : op.getOutputTensorTypes()) {
+      Attribute zeroAttr = b.getZeroAttr(resultType.getElementType());
+      Value zeroValue = b.create<ConstantOp>(op.getLoc(), zeroAttr);
+      Value zeroTensor =
+          b.create<InitTensorOp>(op.getLoc(), resultType, zeroValue);
+      nonInitResultTensors.push_back(zeroTensor);
+    }
+  }
+  loopArgs.append(nonInitResultTensors.begin(), nonInitResultTensors.end());
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, /*iterArgInitValues*/ initTensors, iteratorTypes,
+      loopRanges, /*iterArgInitValues*/ loopArgs, iteratorTypes,
       [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
         auto &b = ScopedContext::getBuilderRef();
         auto loc = ScopedContext::getLocation();
@@ -391,20 +405,25 @@
         else
          interchangedIvs.assign(ivs.begin(), ivs.end());
 
-        assert(op.getNumInitTensors() == iterArgs.size() &&
+        assert((op.getNumInitTensors() + nonInitResultTensors.size() ==
+                iterArgs.size()) &&
               "num init tensors must match number of loop iter arguments");
         // This uses knowledge about position of the init tensor in the list
         // of operands.
         auto operands = llvm::to_vector<4>(op.getShapedOperands());
-        std::copy(iterArgs.begin(), iterArgs.end(),
-                  operands.begin() + op.getNumInputsAndOutputBuffers());
+        if (op.getNumInitTensors()) {
+          std::copy(iterArgs.begin(), iterArgs.end(),
+                    operands.begin() + op.getNumInputsAndOutputBuffers());
+        }
+        SmallVector<Value, 1> localNonInitResultTensors = {};
+        if (!nonInitResultTensors.empty()) {
+          localNonInitResultTensors.append(
+              iterArgs.begin() + op.getNumInitTensors(), iterArgs.end());
+        }
 
-        SmallVector<Value, 4> tiledOperands =
-            makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap,
-                            interchangedIvs, tileSizes, allShapeSizes);
-        auto nonShapedOperands = op.getAssumedNonShapedOperands();
-        tiledOperands.append(nonShapedOperands.begin(),
-                             nonShapedOperands.end());
+        SmallVector<Value, 4> tiledValues = makeTiledShapes(
+            b, loc, op, operands, localNonInitResultTensors,
+            shapeSizesToLoopsMap, interchangedIvs, tileSizes, allShapeSizes);
 
         // If LinalgOp has results, they must all be tied to init tensors.
         // We enforce this to ensure all tiled ops have been rewritten in
@@ -414,24 +433,34 @@
         // This would not be the case with a special terminator op that
         // generates the whole tensor (instead of inserting a subtensor). But
         // the generator-based abstraction has other issues.
-        assert(op.getNumInitTensors() == op->getNumResults() &&
-               "expected same number of init tensors as number of results");
+        assert((op.getNumInitTensors() == 0 ||
+                op.getNumInitTensors() == op->getNumResults()) &&
+               "expected number of init tensors to be zero or same as number "
+               "of results");
 
         // Handle init tensor operands.
         // This uses knowledge about position of the init tensor in the list
         // of operands.
         // TODO: InterfaceAdaptor ?
         SmallVector<Type, 4> resultTensorTypes;
-        for (auto idx : llvm::seq<unsigned>(0, op.getNumInitTensors()))
+        for (auto idx :
+             llvm::seq<unsigned>(0, op.getOperation()->getNumResults()))
           resultTensorTypes.push_back(
-              tiledOperands[op.getNumInputsAndOutputBuffers() + idx].getType());
+              tiledValues[op.getNumInputsAndOutputBuffers() + idx].getType());
 
-        res = op.clone(b, loc, resultTensorTypes, tiledOperands);
+        SmallVector<Value, 4> clonedOpOperands = llvm::to_vector<4>(
+            ArrayRef<Value>(tiledValues).take_front(op.getNumShapedOperands()));
+        auto nonShapedOperands = op.getAssumedNonShapedOperands();
+        clonedOpOperands.append(nonShapedOperands.begin(),
+                                nonShapedOperands.end());
+        res = op.clone(b, loc, resultTensorTypes, clonedOpOperands);
 
         // Insert a subtensor_insert for each init subtensor.
-        for (unsigned idx = 0, e = op.getNumInitTensors(); idx != e; ++idx) {
+        for (unsigned idx = 0,
+                      e = op.getNumInitTensors() + nonInitResultTensors.size();
+             idx != e; ++idx) {
           Value initTensor =
-              tiledOperands[op.getNumInputsAndOutputBuffers() + idx];
+              tiledValues[op.getNumInputsAndOutputBuffers() + idx];
           if (auto subtensor = initTensor.getDefiningOp<SubTensorOp>()) {
             tensorResults.push_back(b.create<SubTensorInsertOp>(
                 loc, subtensor.source().getType(), res->getResult(idx),
@@ -581,10 +610,10 @@
 static void insertTilingPatterns(OwningRewritePatternList &patterns,
                                  const LinalgTilingOptions &options,
                                  MLIRContext *ctx) {
-  RewritePatternList<
+  RewritePatternList<GenericOp, IndexedGenericOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-      >::insert(patterns, options, ctx);
+                     >::insert(patterns, options, ctx);
 }
 
 static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -125,15 +125,8 @@
   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
-  // If LinalgOp has results, they must all be tied to init tensors.
-  // We enforce this to ensure all tiled ops have been rewritten in
-  // "init tensor" form. This ensures tiling has anchor values into which to
-  // subtensor / subtensor_insert. Otherwise tiling would need to allocate which
-  // is not acceptable.
-  // This would not be the case with a special terminator op that generates the
-  // whole tensor (instead of inserting a subtensor). But the generator-based
-  // abstraction has other issues.
-  if (linalgOp.getNumInitTensors() != linalgOp->getNumResults())
+  if (linalgOp.getNumInitTensors() != 0 &&
+      linalgOp.getNumInitTensors() != linalgOp->getNumResults())
     return failure();
 
   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
-// | mlir-opt | FileCheck %s
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
 //
@@ -698,3 +697,31 @@
 // CHECK-LABEL: func @memref_reshape_zero_dim
 // CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
 // CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+// -----
+
+func @fill_tensor(%arg0 : f32, %arg1 : i32, %arg2 : vector<2x3xf32>)
+{
+  %0 = linalg.init_tensor %arg0 : f32 into tensor<f32>
+  %1 = linalg.init_tensor %arg0 : f32 into tensor<?x?xf32>
+  %2 = linalg.init_tensor %arg0 : f32 into tensor<42x21xf32>
+  %3 = linalg.init_tensor %arg1 : i32 into tensor<?x?xi32>
+  %4 = linalg.init_tensor %arg2 :
+    vector<2x3xf32> into tensor<?x?xvector<2x3xf32>>
+  return
+}
+// CHECK-LABEL: func @fill_tensor
+// CHECK: linalg.init_tensor %{{.+}} : f32 into tensor<f32>
+// CHECK: linalg.init_tensor %{{.+}} : f32 into tensor<?x?xf32>
+// CHECK: linalg.init_tensor %{{.+}} : f32 into tensor<42x21xf32>
+// CHECK: linalg.init_tensor %{{.+}} : i32 into tensor<?x?xi32>
+// CHECK: linalg.init_tensor %{{.+}} : vector<2x3xf32> into tensor<?x?xvector<2x3xf32>>
+
+// -----
+
+func @fill_tensor_mismatch(%arg0 : f32)
+{
+  // expected-error @+1 {{mismatch in value type and element type of the result tensor}}
+  %0 = linalg.init_tensor %arg0 : f32 into tensor<?x?xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -mlir-disable-threading=true | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @matmul_tensors(
 //  CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
@@ -26,3 +26,40 @@
 // CHECK:   return %[[TD0]] : tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func @generic_op_tensors(
+  %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
+                      affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3: f32):
+      %1 = addf %arg2, %arg3 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_tensors
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[TD0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC0:.+]] = %[[INIT]]) -> (tensor<?x?x?xf32>) {
+// CHECK: %[[TD1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC1:.+]] = %[[TC0]]) -> (tensor<?x?x?xf32>) {
+// CHECK: %[[TD2:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC2:.+]] = %[[TC1]]) -> (tensor<?x?x?xf32>) {
+// CHECK:   %[[STARG0:.+]] = subtensor %[[ARG0]][{{.+}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+// CHECK:   %[[STARG1:.+]] = subtensor %[[ARG1]][{{.+}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+// CHECK:   %[[STRETURN:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK:   %[[TD:.+]] = subtensor_insert %[[STRETURN]] into %[[TC2]]
+// CHECK:   scf.yield %[[TD]]
+// CHECK: }
+// CHECK: scf.yield %[[TD2]]
+// CHECK: }
+// CHECK: scf.yield %[[TD1]]
+// CHECK: }
+// CHECK: return %[[TD0]]
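Note: for orientation, the CHECK lines of @generic_op_tensors above describe a loop nest of roughly the following shape (a hand-written sketch of the expected tiled output, not verbatim pass output; value names and the elided `...` offsets, sizes, and strides are illustrative):

    %zero = constant 0.0 : f32
    %init = linalg.init_tensor %zero : f32 into tensor<?x?x?xf32>
    %0 = scf.for %iv0 = ... iter_args(%t0 = %init) -> (tensor<?x?x?xf32>) {
      %1 = scf.for %iv1 = ... iter_args(%t1 = %t0) -> (tensor<?x?x?xf32>) {
        %2 = scf.for %iv2 = ... iter_args(%t2 = %t1) -> (tensor<?x?x?xf32>) {
          // Tiles of the two inputs for this iteration.
          %sta = subtensor %arg0[...] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
          %stb = subtensor %arg1[...] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
          // The original linalg.generic cloned onto the tiles.
          %st = linalg.generic ... ins(%sta, %stb : ...) -> tensor<?x?x?xf32>
          // Destructive update of the loop-carried result tensor.
          %upd = subtensor_insert %st into %t2[...]
          scf.yield %upd : tensor<?x?x?xf32>
        }
        scf.yield %2 : tensor<?x?x?xf32>
      }
      scf.yield %1 : tensor<?x?x?xf32>
    }
    return %0 : tensor<?x?x?xf32>

The outermost iter_args is seeded by the `linalg.init_tensor` created during tiling, which is why the op does not need to be tied to a user-provided init tensor for the pattern to apply.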