diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -85,6 +85,18 @@
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
+
+/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
+/// appropriate sizes to reduce the rank of `tensor` to `targetType`.
+Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
+                                                Value tensor,
+                                                RankedTensorType targetType);
+
+/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
+/// the sizes of `dest`, inserting the lower-rank `tensor` into `dest`.
+Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
+                                               Value tensor, Value dest);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -202,6 +202,59 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
 
 namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+/// This is a builder type that keeps local references to arguments. Arguments
+/// that are passed into the builder must outlive the builder.
+class RankedTensorType::Builder {
+public:
+  /// Build from another RankedTensorType.
+  explicit Builder(RankedTensorType other)
+      : shape(other.getShape()), elementType(other.getElementType()),
+        encoding(other.getEncoding()) {}
+
+  /// Build from scratch.
+  Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
+      : shape(shape), elementType(elementType), encoding(encoding) {}
+
+  Builder &setShape(ArrayRef<int64_t> newShape) {
+    shape = newShape;
+    return *this;
+  }
+
+  Builder &setElementType(Type newElementType) {
+    elementType = newElementType;
+    return *this;
+  }
+
+  Builder &setEncoding(Attribute newEncoding) {
+    encoding = newEncoding;
+    return *this;
+  }
+
+  /// Create a new RankedTensorType by erasing dimension `dim` from the shape.
+  /// The result is built straight from a local shape vector, so the builder
+  /// itself is left unchanged and never refers to that temporary storage.
+  RankedTensorType dropDim(unsigned dim) {
+    assert(dim < shape.size() && "dropDim: dimension index out of range");
+    SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
+    newShape.erase(newShape.begin() + dim);
+    return RankedTensorType::get(newShape, elementType, encoding);
+  }
+
+  operator RankedTensorType() {
+    return RankedTensorType::get(shape, elementType, encoding);
+  }
+
+private:
+  ArrayRef<int64_t> shape;
+  Type elementType;
+  Attribute encoding;
+};
+
 //===----------------------------------------------------------------------===//
 // MemRefType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -701,6 +701,11 @@
       return $_get(elementType.getContext(), shape, elementType, encoding);
     }]>
   ];
+  let extraClassDeclaration = [{
+    /// This is a builder type that keeps local references to arguments.
+    /// Arguments that are passed into the builder must outlive the builder.
+    class Builder;
+  }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -852,7 +852,6 @@
     auto filterType = filter.getType().dyn_cast<RankedTensorType>();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
-    auto inputShape = inputType.getShape();
     auto filterShape = filterType.getShape();
     auto outputShape = outputType.getShape();
 
@@ -860,52 +859,47 @@
     // of size 1. Other cases can rely on tiling to reduce to such cases.
     int64_t fhSize = filterShape[0], fwSize = filterShape[1];
     int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    if (!(fhSize == 1 && ohSize == 1) && !(fwSize == 1 && owSize == 1))
+    bool removeH = (fhSize == 1 && ohSize == 1);
+    bool removeW = (fwSize == 1 && owSize == 1);
+    if (!removeH && !removeW)
       return failure();
-    bool removeH = ohSize == 1;
 
     // Get new shapes and types for all operands by removing the size-1
     // dimension.
+    using RTTBuilder = RankedTensorType::Builder;
+    auto newInputType = RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
+    auto newFilterType = RTTBuilder(filterType).dropDim(removeH ? 0 : 1);
+    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
-    SmallVector<int64_t, 3> newInputShape{
-        inputShape[0], inputShape[removeH ? 2 : 1], inputShape[3]};
-    auto newInputType = RankedTensorType::get(
-        newInputShape, inputType.getElementType(), inputType.getEncoding());
-
-    SmallVector<int64_t, 3> newFilterShape{filterShape[removeH ? 1 : 0],
-                                           filterShape[2], filterShape[3]};
-    auto newFilterType = RankedTensorType::get(
-        newFilterShape, filterType.getElementType(), filterType.getEncoding());
-
-    SmallVector<int64_t, 3> newOutputShape{
-        outputShape[0], outputShape[removeH ? 2 : 1], outputShape[3]};
-    auto newOutputType = RankedTensorType::get(
-        newOutputShape, outputType.getElementType(), outputType.getEncoding());
-
-    SmallVector<ReassociationIndices, 3> ioReshapeIndices = {{0}, {1, 2}, {3}};
-    SmallVector<ReassociationIndices, 3> fReshapeIndices = {{0, 1}, {2}, {3}};
-
-    // Reshape all operands for 1-D convolution.
+    // Rank-reduce operands.
     Location loc = convOp.getLoc();
-    Value newInput = rewriter.create<linalg::TensorCollapseShapeOp>(
-        loc, newInputType, input, ioReshapeIndices);
-    Value newFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
-        loc, newFilterType, filter, fReshapeIndices);
-    Value newOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
-        loc, newOutputType, output, ioReshapeIndices);
-
-    // We need to shrink the strides and dilations too.
-    auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
-    auto stridesAttr = rewriter.getI64VectorAttr(stride);
-    auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
-    auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
+    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, input, newInputType);
+    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, filter, newFilterType);
+    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, output, newOutputType);
+
+    // Rank-reduce strides and dilations too.
+    // TODO: dropDim 1-liner helper.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    strides.erase(strides.begin() + (removeH ? 0 : 1));
+    auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+    auto dilations =
+        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
         loc, newOutputType, ValueRange{newInput, newFilter},
         ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
-    rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
-        convOp, outputType, conv1DOp.getResult(0), ioReshapeIndices);
+    // Insert back.
+    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, conv1DOp.getResult(0), output);
+    rewriter.replaceOp(convOp, inserted);
+
     return success();
   };
 };
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1070,6 +1070,27 @@
   return OpFoldResult();
 }
 
+Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
+    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
+  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
+  unsigned rank = rankedTensorType.getRank();
+  auto shape = rankedTensorType.getShape();
+  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes;
+  for (unsigned i = 0, e = rank; i < e; ++i) {
+    OpFoldResult dim;
+    if (rankedTensorType.isDynamicDim(i))
+      dim = b.createOrFold<tensor::DimOp>(
+          loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
+    else
+      dim = b.getIndexAttr(shape[i]);
+    sizes.push_back(dim);
+  }
+  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
+                                                offsets, sizes, strides);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceOp
 //===----------------------------------------------------------------------===//
@@ -1309,6 +1330,29 @@
               InsertSliceOpSourceCastInserter>(context);
 }
 
+Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
+                                                             Location loc,
+                                                             Value tensor,
+                                                             Value dest) {
+  auto rankedTensorType = dest.getType().cast<RankedTensorType>();
+  unsigned rank = rankedTensorType.getRank();
+  auto shape = rankedTensorType.getShape();
+  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes;
+  for (unsigned i = 0, e = rank; i < e; ++i) {
+    OpFoldResult dim;
+    if (rankedTensorType.isDynamicDim(i))
+      dim = b.createOrFold<tensor::DimOp>(
+          loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
+    else
+      dim = b.getIndexAttr(shape[i]);
+    sizes.push_back(dim);
+  }
+  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
+                                               sizes, strides);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
--- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -10,21 +10,20 @@
   return %0 : tensor<4x1x2x8xf32>
 }
 
-// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x6x3xf32> into tensor<4x6x3xf32>
-// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
-// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<1x2x3x8xf32> into tensor<2x3x8xf32>
-// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x1x2x8xf32> into tensor<4x2x8xf32>
+// CHECK: %[[INPUT_1D:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [4, 1, 6, 3] [1, 1, 1, 1] : tensor<4x1x6x3xf32> to tensor<4x6x3xf32>
+// CHECK: %[[FILTER_1D:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [1, 2, 3, 8] [1, 1, 1, 1] : tensor<1x2x3x8xf32> to tensor<2x3x8xf32>
+// CHECK: %[[INIT_1D:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1] : tensor<4x1x2x8xf32> to tensor<4x2x8xf32>
 // CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
 // CHECK-SAME: dilations = dense<3> : vector<1xi64>
 // CHECK-SAME: strides = dense<2> : vector<1xi64>
 // CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<4x6x3xf32>, tensor<2x3x8xf32>)
 // CHECK-SAME: outs(%[[INIT_1D]] : tensor<4x2x8xf32>)
-// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
+// CHECK: %[[CONV_2D:.+]] = tensor.insert_slice %[[CONV_1D]] into %[[INIT]]
+// CHECK-SAME{LITERAL}: [0, 0, 0, 0] [4, 1, 2, 8] [1, 1, 1, 1] : tensor<4x2x8xf32> into tensor<4x1x2x8xf32>
 // CHECK: return %[[CONV_2D]]
-
 // -----
 
 // CHECK-LABEL: func @conv2d_nhwc_qxqx1xq_tensor
@@ -37,19 +36,23 @@
   return %0 : tensor<?x?x1x?xf32>
 }
 
-// CHECK: %[[INPUT_1D:.+]] = linalg.tensor_collapse_shape %[[INPUT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
-// CHECK: %[[FILTER_1D:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
-// CHECK-SAME{LITERAL}: [[0, 1], [2], [3]] : tensor<?x1x?x?xf32> into tensor<?x?x?xf32>
-// CHECK: %[[INIT_1D:.+]] = linalg.tensor_collapse_shape %[[INIT]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[INPUT_1D:.+]] = tensor.extract_slice %[[INPUT]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
+// CHECK: %[[FILTER_1D:.+]] = tensor.extract_slice %[[FILTER]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x1x?x?xf32> to tensor<?x?x?xf32>
+// CHECK: %[[INIT_1D:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
 // CHECK: %[[CONV_1D:.+]] = linalg.conv_1d_nwc_wcf
 // CHECK-SAME: dilations = dense<2> : vector<1xi64>
 // CHECK-SAME: strides = dense<3> : vector<1xi64>
 // CHECK-SAME: ins(%[[INPUT_1D]], %[[FILTER_1D]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
 // CHECK-SAME: outs(%[[INIT_1D]] : tensor<?x?x?xf32>)
-// CHECK: %[[CONV_2D:.+]] = linalg.tensor_expand_shape %[[CONV_1D]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2], [3]] : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
+// CHECK: %[[CONV_2D:.+]] = tensor.insert_slice %[[CONV_1D]] into %[[INIT]]
+// CHECK-SAME: [0, 0, 0, 0] [%{{.*}}, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] :
+// CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
 // CHECK: return %[[CONV_2D]]
 
 // -----