diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1591,47 +1591,6 @@ let summary = "floating point division operation"; } -//===----------------------------------------------------------------------===// -// DynamicTensorFromElementsOp -//===----------------------------------------------------------------------===// - -def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements", - [RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> { - string summary = "Creates a dynamically sized tensor from elements"; - string description = [{ - This operation creates a dynamically sized tensor with elements of any type. - It expects one index operand per dynamic extent of the result tensor. - - The body region defines the tensor's elements. It takes index operands as - its region arguments that span the index space. The element at the given - position is yielded with the `yield` operation (see `YieldOp`). There is - no defined ordering to the invocations of the body. It is conceptually - a "parallel map" operation. - - Example: - - ```mlir - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - ... - yield %elem : f32 - } : tensor - ``` - }]; - - let arguments = (ins Variadic:$dynamicExtents); - let results = (outs AnyRankedTensor:$result); - let regions = (region SizedRegion<1>:$body); - - let builders = [ - // Build op and populate its body per callback function. 
- OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents, - "function_ref")>, - ]; - - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // ExpOp //===----------------------------------------------------------------------===// @@ -1672,46 +1631,6 @@ let summary = "base-2 exponential of the specified value"; } -//===----------------------------------------------------------------------===// -// TensorFromElementsOp -//===----------------------------------------------------------------------===// - -def TensorFromElementsOp : Std_Op<"tensor_from_elements", [ - NoSideEffect, - TypesMatchWith<"operand types match result element type", - "result", "elements", "SmallVector(" - "$_self.cast().getDimSize(0), " - "$_self.cast().getElementType())"> - ]> { - string summary = "tensor from elements operation."; - string description = [{ - Create a 1D tensor from a range of same-type arguments. - - Example: - - ```mlir - tensor_from_elements(i_1, ..., i_N) : tensor - ``` - }]; - - let arguments = (ins Variadic:$elements); - let results = (outs 1DTensorOf<[AnyType]>:$result); - - let assemblyFormat = "$elements attr-dict `:` type($result)"; - - // This op is fully verified by its traits. - let verifier = ?; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>, - // Special case builder for when `elements` has size >=1. 
- OpBuilderDAG<(ins "ValueRange":$elements)> - ]; - - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// @@ -3837,24 +3756,6 @@ let hasCanonicalizer = 1; } -//===----------------------------------------------------------------------===// -// YieldOp -//===----------------------------------------------------------------------===// - -def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator, - HasParent<"DynamicTensorFromElementsOp">]> { - let summary = "Yield a value from a region"; - let description = [{ - This operation is used to yield a single value from a within a region. It - is used to create dynamically sized tensors - (see `DynamicTensorFromElementsOp`). - }]; - - let arguments = (ins AnyType:$value); - let assemblyFormat = "$value attr-dict `:` type($value)"; - let verifier = ?; -} - //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -13,6 +13,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -10,6 +10,7 @@ #define TENSOR_OPS include "mlir/Dialect/Tensor/IR/TensorBase.td" +include 
"mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" class Tensor_Op traits = []> @@ -105,4 +106,109 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// FromElementsOp +//===----------------------------------------------------------------------===// + +def Tensor_FromElementsOp : Tensor_Op<"from_elements", [ + NoSideEffect, + TypesMatchWith<"operand types match result element type", + "result", "elements", "SmallVector(" + "$_self.cast().getDimSize(0), " + "$_self.cast().getElementType())"> + ]> { + string summary = "tensor from elements operation."; + string description = [{ + Create a 1D tensor from a range of same-type arguments. + + Example: + + ```mlir + tensor.from_elements(i_1, ..., i_N) : tensor + ``` + }]; + + let arguments = (ins Variadic:$elements); + let results = (outs 1DTensorOf<[AnyType]>:$result); + + let assemblyFormat = "$elements attr-dict `:` type($result)"; + + // This op is fully verified by its traits. + let verifier = ?; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>, + // Special case builder for when `elements` has size >=1. + OpBuilderDAG<(ins "ValueRange":$elements)> + ]; + + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// GenerateOp +//===----------------------------------------------------------------------===// + +def Tensor_GenerateOp : Tensor_Op<"generate", + [RecursiveSideEffects, + SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { + string summary = "Creates a dynamically sized tensor from elements"; + string description = [{ + This operation creates a dynamically sized tensor with elements of any type. + It expects one index operand per dynamic extent of the result tensor. + + The body region defines the tensor's elements. 
It takes index operands as + its region arguments that span the index space. The element at the given + position is yielded with the `yield` operation (see `YieldOp`). There is + no defined ordering to the invocations of the body. It is conceptually + a "parallel map" operation. + + Example: + + ```mlir + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + ... + yield %elem : f32 + } : tensor + ``` + }]; + + let arguments = (ins Variadic:$dynamicExtents); + let results = (outs AnyRankedTensor:$result); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "$dynamicExtents $body attr-dict `:` type($result)"; + + let builders = [ + // Build op and populate its body per callback function. + OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents, + "function_ref")>, + ]; + + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +def Tensor_YieldOp : Tensor_Op<"yield", + [NoSideEffect, ReturnLike, Terminator, + HasParent<"::mlir::tensor::GenerateOp">]> { + let summary = "Yield a value from a region"; + let description = [{ + This operation is used to yield a single value from a within a region. It + is used to create dynamically sized tensors + (see `tensor.generate` op). + }]; + + let arguments = (ins AnyType:$value); + let assemblyFormat = "$value attr-dict `:` type($value)"; + // Dummy builder to appease code in templated ensureTerminator that + // GenerateOp's auto-generated parser calls. 
+ let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>]; + let verifier = ?; +} + #endif // TENSOR_OPS diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td @@ -14,6 +14,7 @@ def TensorBufferize : FunctionPass<"tensor-bufferize"> { let summary = "Bufferize the `tensor` dialect"; let constructor = "mlir::createTensorBufferizePass()"; + let dependentDialects = ["scf::SCFDialect"]; } #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -20,6 +20,7 @@ MLIREDSC MLIRIR MLIRShape + MLIRTensor MLIRPass MLIRSCF MLIRTransforms diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -113,7 +113,7 @@ Value rankDiff = rewriter.create(loc, indexTy, greaterRank, lesserRank); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getExtentTensorType(op.getContext()), ValueRange{greaterRank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value outputDimension = args[0]; @@ -151,7 +151,7 @@ greaterRankOperandExtent); b.create(loc, broadcastedExtent); }); - b.create(loc, ifOp.getResult(0)); + b.create(loc, ifOp.getResult(0)); }); return success(); } @@ -184,7 +184,7 @@ } Type indexTy = rewriter.getIndexType(); Value tensor = - rewriter.create(loc, indexTy, extentOperands); + rewriter.create(loc, indexTy, extentOperands); Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); rewriter.replaceOpWithNewOp(op, resultTy, tensor); return 
success(); @@ -503,7 +503,7 @@ if (op.getType().isa()) return failure(); - // For ranked tensor arguments, lower to `tensor_from_elements`. + // For ranked tensor arguments, lower to `tensor.from_elements`. auto loc = op.getLoc(); ShapeOfOp::Adaptor transformed(operands); Value tensor = transformed.arg(); @@ -526,22 +526,22 @@ } // Materialize extent tensor. - Value staticExtentTensor = rewriter.create( + Value staticExtentTensor = rewriter.create( loc, rewriter.getIndexType(), extentValues); rewriter.replaceOpWithNewOp(op, op.getType(), staticExtentTensor); return success(); } - // Lower to `dynamic_tensor_from_elements` otherwise. + // Lower to `tensor.generate` otherwise. auto *ctx = rewriter.getContext(); Value rank = rewriter.create(loc, tensor); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getExtentTensorType(ctx), ValueRange{rank}, [&](OpBuilder &b, Location loc, ValueRange args) { Value dim = args.front(); Value extent = b.create(loc, tensor, dim); - b.create(loc, extent); + b.create(loc, extent); }); return success(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1392,9 +1392,8 @@ return getResult(); } - // Fold dim to the operand of dynamic_tensor_from_elements. - if (auto fromElements = - dyn_cast_or_null(definingOp)) { + // Fold dim to the operand of tensor.generate. 
+ if (auto fromElements = dyn_cast_or_null(definingOp)) { auto resultType = fromElements.getResult().getType().cast(); // The case where the type encodes the size of the dimension is handled @@ -1734,258 +1733,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// DynamicTensorFromElementsOp -//===----------------------------------------------------------------------===// - -static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser, - OperationState &result) { - // Parse operands. - SmallVector dynamicExtents; - Type indexTy = parser.getBuilder().getIndexType(); - if (parser.parseOperandList(dynamicExtents) || - parser.resolveOperands(dynamicExtents, indexTy, result.operands)) - return failure(); - - // Parse body. - Region *body = result.addRegion(); - if (parser.parseRegion(*body, {}, {})) - return failure(); - - // Parse result type. - Type resultType; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType)) - return failure(); - result.addTypes(resultType); - - return success(); -} - -static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) { - p << "dynamic_tensor_from_elements " << op.dynamicExtents(); - p.printRegion(op.body()); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getType(); -} - -static LogicalResult verify(DynamicTensorFromElementsOp op) { - // Ensure that the tensor type has as many dynamic dimensions as are specified - // by the operands. - RankedTensorType resultTy = op.getType().cast(); - if (op.getNumOperands() != resultTy.getNumDynamicDims()) - return op.emitError("must have as many index operands as dynamic extents " - "in the result type"); - - // Ensure that region arguments span the index space. 
- if (!llvm::all_of(op.body().getArgumentTypes(), - [](Type ty) { return ty.isIndex(); })) - return op.emitError("all body arguments must be index"); - if (op.body().getNumArguments() != resultTy.getRank()) - return op.emitError("must have one body argument per input dimension"); - - // Ensure that the region yields an element of the right type. - auto yieldOp = - llvm::cast(op.body().getBlocks().front().getTerminator()); - if (yieldOp.value().getType() != resultTy.getElementType()) - return op.emitOpError( - "body must be terminated with a `yield` operation of the tensor " - "element type"); - - return success(); -} - -void DynamicTensorFromElementsOp::build( - OpBuilder &b, OperationState &result, Type resultTy, - ValueRange dynamicExtents, - function_ref bodyBuilder) { - build(b, result, resultTy, dynamicExtents); - - // Build and populate body. - OpBuilder::InsertionGuard guard(b); - Region *bodyRegion = result.regions.front().get(); - auto rank = resultTy.cast().getRank(); - SmallVector argumentTypes(rank, b.getIndexType()); - Block *bodyBlock = - b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); - bodyBuilder(b, result.location, bodyBlock->getArguments()); -} - -namespace { - -/// Canonicalizes dynamic_tensor_from_elements operations with a constant -/// operand into the equivalent operation with the operand expressed in the -/// result type, instead. We also insert a type cast to make sure that the -/// resulting IR is still well-typed. 
-struct StaticDynamicTensorFromElements - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements, - PatternRewriter &rewriter) const final { - auto resultType = - tensorFromElements.getResult().getType().cast(); - - if (resultType.hasStaticShape()) - return failure(); - - SmallVector newOperands; - SmallVector newShape; - auto operandsIt = tensorFromElements.dynamicExtents().begin(); - - for (int64_t dim : resultType.getShape()) { - if (dim != RankedTensorType::kDynamicSize) { - newShape.push_back(dim); - continue; - } - APInt index; - if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { - newShape.push_back(RankedTensorType::kDynamicSize); - newOperands.push_back(*operandsIt++); - continue; - } - newShape.push_back(index.getSExtValue()); - operandsIt++; - } - - if (newOperands.size() == tensorFromElements.dynamicExtents().size()) - return failure(); - - auto loc = tensorFromElements.getLoc(); - auto newOp = rewriter.create( - loc, RankedTensorType::get(newShape, resultType.getElementType()), - newOperands); - rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), - newOp.body().begin()); - rewriter.replaceOpWithNewOp(tensorFromElements, resultType, - newOp); - return success(); - } -}; - -/// Canonicalizes the pattern of the form -/// -/// %tensor = dynamic_tensor_from_elements %x { -/// ^bb0(%arg0: index): // no predecessors -/// -/// yield %1 : index -/// } : tensor -/// %extracted_element = tensor.extract %tensor[%c0] : tensor -/// -/// to just with %arg0 replaced by %c0. We only do this if the -/// dynamic_tensor_from_elements operation has no side-effects. 
-struct ExtractFromDynamicTensorFromElements - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - auto tensorFromElements = - extract.tensor().getDefiningOp(); - if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) - return failure(); - - BlockAndValueMapping mapping; - Block *body = tensorFromElements.getBody(); - mapping.map(body->getArguments(), extract.indices()); - for (auto &op : body->without_terminator()) - rewriter.clone(op, mapping); - - auto yield = cast(body->getTerminator()); - - rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); - return success(); - } -}; - -/// Canonicalizes the pattern of the form -/// -/// %val = tensor.cast %source : : tensor to tensor<2xi32> -/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> -/// -/// to -/// -/// %extracted_element = tensor.extract %source[%c0] : tensor -struct ExtractFromTensorCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - auto tensorCast = extract.tensor().getDefiningOp(); - if (!tensorCast) - return failure(); - - rewriter.replaceOpWithNewOp(extract, tensorCast.source(), - extract.indices()); - return success(); - } -}; - -} // namespace - -void DynamicTensorFromElementsOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - // TODO: Move extract patterns to tensor::ExtractOp. 
- results.insert(context); -} - -//===----------------------------------------------------------------------===// -// TensorFromElementsOp -//===----------------------------------------------------------------------===// - -void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, - Type elementType, ValueRange elements) { - Type resultTy = RankedTensorType::get({static_cast(elements.size())}, - elementType); - result.addOperands(elements); - result.addTypes(resultTy); -} - -void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, - ValueRange elements) { - assert(!elements.empty() && "expected at least one element"); - build(builder, result, elements.front().getType(), elements); -} - -namespace { - -// Canonicalizes the pattern of the form -// -// %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32> -// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32> -// -// to just %element. -struct ExtractElementFromTensorFromElements - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter &rewriter) const final { - if (extract.indices().size() != 1) - return failure(); - - auto tensorFromElements = dyn_cast_or_null( - extract.tensor().getDefiningOp()); - if (tensorFromElements == nullptr) - return failure(); - - APInt index; - if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) - return failure(); - rewriter.replaceOp(extract, - tensorFromElements.getOperand(index.getZExtValue())); - return success(); - } -}; - -} // namespace - -void TensorFromElementsOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -35,70 +35,6 @@ }; } // namespace -namespace { -class BufferizeDynamicTensorFromElementsOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(DynamicTensorFromElementsOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - // Allocate memory. - Location loc = op.getLoc(); - DynamicTensorFromElementsOp::Adaptor transformed(operands); - RankedTensorType tensorType = op.getType().cast(); - MemRefType memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value result = - rewriter.create(loc, memrefType, transformed.dynamicExtents()); - - // Collect loop bounds. - int64_t rank = tensorType.getRank(); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - SmallVector lowerBounds(rank, zero); - SmallVector steps(rank, one); - SmallVector upperBounds; - int nextDynamicIndex = 0; - for (int i = 0; i < rank; i++) { - Value upperBound = - tensorType.isDynamicDim(i) - ? transformed.dynamicExtents()[nextDynamicIndex++] - : rewriter.create(loc, memrefType.getDimSize(i)); - upperBounds.push_back(upperBound); - } - - // Generate tensor elements with a parallel loop that stores into - // each element of the resulting memref. - // - // This is a bit tricky. We cannot simply clone the ops because when an op - // is cloned, it must be legalized. However, we want to allow arbitrary ops - // in the body that we don't necessarily have legalization patterns for as - // part of this dialect conversion invocation. - // - // To accomplish this, we use mergeBlockBefore to "move" this op's body - // into the scf.parallel's body. 
- auto parallel = - rewriter.create(loc, lowerBounds, upperBounds, steps); - Block *parallelBody = parallel.getBody(); - rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), - parallelBody->getArguments()); - // Replace the inlined yield op with a store op. The scf.parallel's builder - // already populated an scf.yield at the end, so we don't need to worry - // about creating that. - Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); - rewriter.setInsertionPointAfter(elementYield); - rewriter.replaceOpWithNewOp(elementYield, - elementYield->getOperands()[0], result, - parallelBody->getArguments()); - - rewriter.replaceOp(op, {result}); - return success(); - } -}; -} // namespace - namespace { class BufferizeSelectOp : public OpConversionPattern { public: @@ -117,40 +53,10 @@ }; } // namespace -namespace { -class BufferizeTensorFromElementsOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(TensorFromElementsOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - int numberOfElements = op.elements().size(); - auto resultType = MemRefType::get( - {numberOfElements}, op.getType().cast().getElementType()); - Value result = rewriter.create(op.getLoc(), resultType); - for (auto element : llvm::enumerate(op.elements())) { - Value index = - rewriter.create(op.getLoc(), element.index()); - rewriter.create(op.getLoc(), element.value(), result, index); - } - rewriter.replaceOp(op, {result}); - return success(); - } -}; -} // namespace - void mlir::populateStdBufferizePatterns(MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert< - // clang-format off - BufferizeDimOp, - BufferizeDynamicTensorFromElementsOp, - BufferizeSelectOp, - BufferizeTensorFromElementsOp - // clang-format on - >(typeConverter, context); + patterns.insert(typeConverter, context); } namespace { @@ 
-165,7 +71,6 @@ target.addLegalDialect(); populateStdBufferizePatterns(context, typeConverter, patterns); - target.addIllegalOp(); // We only bufferize the case of tensor selected type and scalar condition, // as that boils down to a select over memref descriptors (don't need to // touch the data). diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -13,5 +13,6 @@ LINK_LIBS PUBLIC MLIRIR + MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" @@ -205,6 +207,223 @@ return {}; } +//===----------------------------------------------------------------------===// +// FromElementsOp +//===----------------------------------------------------------------------===// + +void FromElementsOp::build(OpBuilder &builder, OperationState &result, + Type elementType, ValueRange elements) { + Type resultTy = RankedTensorType::get({static_cast(elements.size())}, + elementType); + result.addOperands(elements); + result.addTypes(resultTy); +} + +void FromElementsOp::build(OpBuilder &builder, OperationState &result, + ValueRange elements) { + assert(!elements.empty() && "expected at least one element"); + build(builder, result, elements.front().getType(), elements); +} + +namespace { + +// Canonicalizes the pattern of the form +// +// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> +// %extracted_element = 
tensor.extract %tensor[%c0] : tensor<1xi32> +// +// to just %element. +struct ExtractElementFromTensorFromElements + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter &rewriter) const final { + if (extract.indices().size() != 1) + return failure(); + + auto tensorFromElements = extract.tensor().getDefiningOp(); + if (tensorFromElements == nullptr) + return failure(); + + APInt index; + if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) + return failure(); + rewriter.replaceOp(extract, + tensorFromElements.getOperand(index.getZExtValue())); + return success(); + } +}; + +} // namespace + +void FromElementsOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// GenerateOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(GenerateOp op) { + // Ensure that the tensor type has as many dynamic dimensions as are specified + // by the operands. + RankedTensorType resultTy = op.getType().cast(); + if (op.getNumOperands() != resultTy.getNumDynamicDims()) + return op.emitError("must have as many index operands as dynamic extents " + "in the result type"); + + // Ensure that region arguments span the index space. + if (!llvm::all_of(op.body().getArgumentTypes(), + [](Type ty) { return ty.isIndex(); })) + return op.emitError("all body arguments must be index"); + if (op.body().getNumArguments() != resultTy.getRank()) + return op.emitError("must have one body argument per input dimension"); + + // Ensure that the region yields an element of the right type. 
+ auto yieldOp = + llvm::cast(op.body().getBlocks().front().getTerminator()); + if (yieldOp.value().getType() != resultTy.getElementType()) + return op.emitOpError( + "body must be terminated with a `yield` operation of the tensor " + "element type"); + + return success(); +} + +void GenerateOp::build( + OpBuilder &b, OperationState &result, Type resultTy, + ValueRange dynamicExtents, + function_ref bodyBuilder) { + build(b, result, resultTy, dynamicExtents); + + // Build and populate body. + OpBuilder::InsertionGuard guard(b); + Region *bodyRegion = result.regions.front().get(); + auto rank = resultTy.cast().getRank(); + SmallVector argumentTypes(rank, b.getIndexType()); + Block *bodyBlock = + b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); + bodyBuilder(b, result.location, bodyBlock->getArguments()); +} + +namespace { + +/// Canonicalizes tensor.generate operations with a constant +/// operand into the equivalent operation with the operand expressed in the +/// result type, instead. We also insert a type cast to make sure that the +/// resulting IR is still well-typed. 
+struct StaticTensorGenerate : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenerateOp tensorFromElements, + PatternRewriter &rewriter) const final { + auto resultType = + tensorFromElements.getResult().getType().cast(); + + if (resultType.hasStaticShape()) + return failure(); + + SmallVector newOperands; + SmallVector newShape; + auto operandsIt = tensorFromElements.dynamicExtents().begin(); + + for (int64_t dim : resultType.getShape()) { + if (dim != RankedTensorType::kDynamicSize) { + newShape.push_back(dim); + continue; + } + APInt index; + if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { + newShape.push_back(RankedTensorType::kDynamicSize); + newOperands.push_back(*operandsIt++); + continue; + } + newShape.push_back(index.getSExtValue()); + operandsIt++; + } + + if (newOperands.size() == tensorFromElements.dynamicExtents().size()) + return failure(); + + auto loc = tensorFromElements.getLoc(); + auto newOp = rewriter.create( + loc, RankedTensorType::get(newShape, resultType.getElementType()), + newOperands); + rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), + newOp.body().begin()); + rewriter.replaceOpWithNewOp(tensorFromElements, resultType, + newOp); + return success(); + } +}; + +/// Canonicalizes the pattern of the form +/// +/// %tensor = tensor.generate %x { +/// ^bb0(%arg0: index): // no predecessors +/// +/// yield %1 : index +/// } : tensor +/// %extracted_element = tensor.extract %tensor[%c0] : tensor +/// +/// to just with %arg0 replaced by %c0. We only do this if the +/// tensor.generate operation has no side-effects. 
+struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
+    // Only a generate op that would be trivially dead (i.e. whose body has no
+    // side effects) may have its body duplicated for a single element.
+    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
+      return failure();
+
+    // Clone the body with the extract indices substituted for the block
+    // arguments; the yielded value then is the extracted element.
+    BlockAndValueMapping mapping;
+    Block *body = tensorFromElements.getBody();
+    mapping.map(body->getArguments(), extract.indices());
+    for (auto &op : body->without_terminator())
+      rewriter.clone(op, mapping);
+
+    auto yield = cast<tensor::YieldOp>(body->getTerminator());
+
+    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
+    return success();
+  }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
+/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
+///
+/// to
+///
+/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
+struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
+    if (!tensorCast)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
+                                                   extract.indices());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the canonicalization patterns that involve tensor.generate.
+void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  // TODO: Move extract patterns to tensor::ExtractOp.
+  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
+                 StaticTensorGenerate>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
@@ -48,10 +49,101 @@
 };
 } // namespace
 
+namespace {
+/// Bufferizes tensor.from_elements: allocates a rank-1 memref with one slot
+/// per element and stores each element at its constant index.
+class BufferizeFromElementsOp
+    : public OpConversionPattern<tensor::FromElementsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    int numberOfElements = op.elements().size();
+    auto resultType = MemRefType::get(
+        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
+    Value result = rewriter.create<AllocOp>(op.getLoc(), resultType);
+    for (auto element : llvm::enumerate(op.elements())) {
+      Value index =
+          rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
+      rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index);
+    }
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Bufferizes tensor.generate: allocates a memref of the result shape and
+/// fills it from an scf.parallel loop that takes over the op's body.
+class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Allocate memory.
+    Location loc = op.getLoc();
+    tensor::GenerateOp::Adaptor transformed(operands);
+    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
+    MemRefType memrefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    Value result =
+        rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents());
+
+    // Collect loop bounds.
+    int64_t rank = tensorType.getRank();
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+    SmallVector<Value, 4> lowerBounds(rank, zero);
+    SmallVector<Value, 4> steps(rank, one);
+    SmallVector<Value, 4> upperBounds;
+    int nextDynamicIndex = 0;
+    for (int i = 0; i < rank; i++) {
+      Value upperBound =
+          tensorType.isDynamicDim(i)
+              ? transformed.dynamicExtents()[nextDynamicIndex++]
+              : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
+      upperBounds.push_back(upperBound);
+    }
+
+    // Generate tensor elements with a parallel loop that stores into
+    // each element of the resulting memref.
+    //
+    // This is a bit tricky. We cannot simply clone the ops because when an op
+    // is cloned, it must be legalized. However, we want to allow arbitrary ops
+    // in the body that we don't necessarily have legalization patterns for as
+    // part of this dialect conversion invocation.
+    //
+    // To accomplish this, we use mergeBlockBefore to "move" this op's body
+    // into the scf.parallel's body.
+    auto parallel =
+        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
+    Block *parallelBody = parallel.getBody();
+    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
+                              parallelBody->getArguments());
+    // Replace the inlined yield op with a store op. The scf.parallel's builder
+    // already populated an scf.yield at the end, so we don't need to worry
+    // about creating that.
+    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
+    rewriter.setInsertionPointAfter(elementYield);
+    rewriter.replaceOpWithNewOp<StoreOp>(elementYield,
+                                         elementYield->getOperands()[0], result,
+                                         parallelBody->getArguments());
+
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
 void mlir::populateTensorBufferizePatterns(
     MLIRContext *context, BufferizeTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeCastOp, BufferizeExtractOp>(typeConverter, context);
+  patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
+                  BufferizeGenerateOp>(typeConverter, context);
 }
 
 namespace {
@@ -62,9 +154,13 @@
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
 
+    populateBufferizeMaterializationLegality(target);
+
     populateTensorBufferizePatterns(context, typeConverter, patterns);
-    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp>();
+    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
+                        tensor::FromElementsOp, tensor::GenerateOp>();
     target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<scf::SCFDialect>();
 
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
+  MLIRSCF
   MLIRTensor
   MLIRTransforms
   )
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Tensor/Transforms/PassDetail.h
@@ -13,6 +13,10 @@
 
 namespace mlir {
 
+namespace scf {
+class SCFDialect;
+} // end namespace scf
+
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -87,14 +87,14 @@
 
 // -----
 
-// Lower `const_shape`
to `tensor_from_elements`. +// Lower `const_shape` to `tensor.from_elements`. // CHECK-LABEL: @const_shape // CHECK-SAME: () -> tensor func @const_shape() -> tensor { // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C3:.*]] = constant 3 : index - // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] + // CHECK: %[[TENSOR3:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor // CHECK: return %[[RESULT]] : tensor %shape = shape.const_shape [1, 2, 3] : tensor @@ -107,7 +107,7 @@ // CHECK-LABEL: func @const_shape_zero_elements // CHECK-SAME: () -> tensor func @const_shape_zero_elements() -> tensor { - // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex> + // CHECK: %[[TENSOR:.*]] = tensor.from_elements : tensor<0xindex> // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor // CHECK: return %[[RESULT]] : tensor %shape = shape.const_shape [] : tensor @@ -204,7 +204,7 @@ // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) func @shape_of_unranked(%arg : tensor<*xf32>) { // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32> - // CHECK: %[[SHAPE:.*]] = dynamic_tensor_from_elements %[[RANK]] { + // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] { // CHECK: ^bb0(%[[I:.*]]: index): // CHECK: %[[EXTENT:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32> // CHECK: yield %[[EXTENT]] : index @@ -233,7 +233,7 @@ // CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[C3]] : tensor<3xindex> %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor return } @@ -244,7 +244,7 @@ // CHECK-LABEL: @shape_of_zero_d // CHECK-SAME: (%[[ARG:.*]]: tensor) 
func @shape_of_zero_d(%arg : tensor) { - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements : tensor<0xindex> %shape = shape.shape_of %arg : tensor -> tensor return } @@ -259,7 +259,7 @@ // CHECK-DAG: %[[C5:.*]] = constant 5 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32> - // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex> + // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor.from_elements %[[C1]], %[[C5]], %[[DYN_DIM]] : tensor<3xindex> %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor return } @@ -321,7 +321,7 @@ // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index - // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] { + // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] { // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index): // CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor @@ -361,7 +361,7 @@ // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index - // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] { + // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] { // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index): // CHECK: 
%[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -11,56 +11,6 @@ return %0 : index } -// CHECK-LABEL: func @dynamic_tensor_from_elements( -// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, -// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { -// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { -// CHECK: %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32> -// CHECK: %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32> -// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref -// CHECK: scf.yield -// CHECK: } -// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref -// CHECK: return %[[RET]] : tensor -// CHECK: } -func @dynamic_tensor_from_elements(%arg: tensor<*xf32>, %rank: index) -> tensor { - %result = dynamic_tensor_from_elements %rank { - ^bb0(%i : index): - %elem = dim %arg, %i : tensor<*xf32> - yield %elem : index - } : tensor - return %result : tensor -} - -// Additional test that checks the logic for intermixed static and dynamic -// extents. 
-// -// CHECK-LABEL: func @dynamic_tensor_from_elements_static_and_dynamic( -// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { -// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex> -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[C16:.*]] = constant 16 : index -// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { -// CHECK: %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index -// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> -// CHECK: scf.yield -// CHECK: } -// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex> -// CHECK: return %[[RET]] : tensor<16x?xindex> -// CHECK: } -func @dynamic_tensor_from_elements_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { - %result = dynamic_tensor_from_elements %arg0 { - ^bb0(%i: index, %j: index): - %sum = addi %i, %j : index - yield %sum : index - } : tensor<16x?xindex> - return %result : tensor<16x?xindex> -} - // CHECK-LABEL: func @select( // CHECK-SAME: %[[PRED:.*]]: i1, // CHECK-SAME: %[[TRUE_VAL:.*]]: tensor, @@ -74,36 +24,3 @@ %0 = select %arg0, %arg1, %arg2 : tensor return %0 : tensor } - -// CHECK-LABEL: func @tensor_from_elements( -// CHECK-SAME: %[[ELEM0:.*]]: index, -// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { -// CHECK: %[[MEMREF:.*]] = alloc() -// CHECK: %[[C0:.*]] = constant 0 : index -// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] -// CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] -// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] -// CHECK: return %[[RET]] : tensor<2xindex> -func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> { - %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex> - return %0 : tensor<2xindex> -} - -// The dynamic_tensor_from_elements op needs to put its body into the -// resulting scf.parallel. 
To handle unknown ops in the body, it cannot clone -// the body because that would require the cloned ops to be legalized -// immediately, which is usually not possible since they might be from various -// other dialects. -// -// CHECK-LABEL: func @unknown_ops_in_body -func @unknown_ops_in_body(%arg0: index) -> tensor { - // CHECK-NOT: dynamic_tensor_from_elements - %tensor = dynamic_tensor_from_elements %arg0 { - ^bb0(%iv: index): - // CHECK: test.source - %0 = "test.source"() : () -> index - yield %0 : index - } : tensor - return %tensor : tensor -} diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -59,16 +59,16 @@ return %1 : f32 } -// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx -// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements( +// Test case: Folding of dim(tensor.generate %idx) -> %idx +// CHECK-LABEL: func @dim_of_tensor.generate( // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index // CHECK-NOT: dim // CHECK: return %[[IDX1]] : index -func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index { +func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index { %c3 = constant 3 : index - %0 = dynamic_tensor_from_elements %arg0, %arg1 { + %0 = tensor.generate %arg0, %arg1 { ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): - yield %c3 : index + tensor.yield %c3 : index } : tensor<2x?x4x?x5xindex> %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex> return %1 : index diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -16,72 +16,6 @@ // ----- -func @dynamic_tensor_from_elements(%m : index) - -> tensor { - // expected-error @+1 {{must have as many index operands as dynamic extents 
in the result type}} - %tnsr = dynamic_tensor_from_elements %m { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor - return %tnsr : tensor -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor { - // expected-error @+1 {{must have one body argument per input dimension}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor - return %tnsr : tensor -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor { - // expected-error @+1 {{all body arguments must be index}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : i64): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor - return %tnsr : tensor -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor { - // expected-error @+2 {{op expects regions to end with 'std.yield', found 'std.return'}} - // expected-note @+1 {{in custom textual format, the absence of terminator implies 'std.yield'}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8.0 : f32 - return %elem : f32 - } : tensor - return %tnsr : tensor -} - -// ----- - -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor { - // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}} - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8 : i32 - yield %elem : i32 - } : tensor - return %tnsr : tensor -} - -// ----- - func @transpose_not_permutation(%v : memref(off + M * i + j)>>) { // expected-error @+1 {{expected a permutation map}} transpose %v (i, j) -> (i, i) : memref(off + M * i + j)>> to memref(off + M * i + j)>> diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir 
--- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -32,17 +32,6 @@ return } -// CHECK-LABEL: @dynamic_tensor_from_elements -func @dynamic_tensor_from_elements(%m : index, %n : index) - -> tensor { - %tnsr = dynamic_tensor_from_elements %m, %n { - ^bb0(%i : index, %j : index, %k : index): - %elem = constant 8.0 : f32 - yield %elem : f32 - } : tensor - return %tnsr : tensor -} - // CHECK-LABEL: @atan func @atan(%arg : f32) -> f32 { %result = atan %arg : f32 @@ -107,4 +96,3 @@ %1 = tensor_load %0 : memref<2xf32> return } - diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -33,14 +33,96 @@ return %0 : tensor<*xf32> } -// CHECK-LABEL: func @extract( +// CHECK-LABEL: func @tensor.extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[IDX:.*]]: index) -> f32 { // CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref // CHECK: %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref // CHECK: return %[[RET]] : f32 // CHECK: } -func @extract(%arg0: tensor, %arg1: index) -> f32 { +func @tensor.extract(%arg0: tensor, %arg1: index) -> f32 { %0 = tensor.extract %arg0[%arg1] : tensor return %0 : f32 } + +// CHECK-LABEL: func @tensor.from_elements( +// CHECK-SAME: %[[ELEM0:.*]]: index, +// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { +// CHECK: %[[MEMREF:.*]] = alloc() +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] +// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] +// CHECK: return %[[RET]] : tensor<2xindex> +func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> { + %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex> + return %0 : tensor<2xindex> +} + +// CHECK-LABEL: func @tensor.generate( +// CHECK-SAME: %[[ARG:.*]]: 
tensor<*xf32>, +// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { +// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { +// CHECK: %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32> +// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref +// CHECK: return %[[RET]] : tensor +// CHECK: } +func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor { + %result = tensor.generate %dynamic_extent { + ^bb0(%i : index): + %elem = dim %arg, %i : tensor<*xf32> + tensor.yield %elem : index + } : tensor + return %result : tensor +} + +// Additional test that checks the logic for intermixed static and dynamic +// extents. +// +// CHECK-LABEL: func @tensor.generate_static_and_dynamic( +// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> { +// CHECK: %[[MEMREF:.*]] = alloc(%[[DYNAMIC_EXTENT]]) : memref<16x?xindex> +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[C16:.*]] = constant 16 : index +// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) { +// CHECK: %[[VAL_7:.*]] = addi %[[I]], %[[J]] : index +// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex> +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[RET:.*]] = tensor_load %[[MEMREF]] : memref<16x?xindex> +// CHECK: return %[[RET]] : tensor<16x?xindex> +// CHECK: } +func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { + %result = tensor.generate %arg0 { + ^bb0(%i: index, %j: index): + %sum = addi %i, %j : index + tensor.yield %sum : index + } : tensor<16x?xindex> + return %result : tensor<16x?xindex> +} + +// The tensor.generate op needs to 
put its body into the +// resulting scf.parallel. To handle unknown ops in the body, it cannot clone +// the body because that would require the cloned ops to be legalized +// immediately, which is usually not possible since they might be from various +// other dialects. +// +// CHECK-LABEL: func @tensor.generate_unknown_ops_in_body +func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor { + // CHECK-NOT: tensor.generate + %tensor = tensor.generate %arg0 { + ^bb0(%iv: index): + // CHECK: test.source + %0 = "test.source"() : () -> index + tensor.yield %0 : index + } : tensor + return %tensor : tensor +} diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -107,3 +107,90 @@ %result = tensor.extract %casted[%c0] : tensor return %result : f32 } + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.from_elements +func @extract_from_tensor.from_elements(%element : index) -> index { + // CHECK-SAME: ([[ARG:%.*]]: index) + %c0 = constant 0 : index + %tensor = tensor.from_elements %element : tensor<1xindex> + %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex> + // CHECK: [[ARG]] : index + return %extracted_element : index +} + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.generate +// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> +func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index { + %size = rank %tensor : tensor<*xf32> + // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]] + %0 = tensor.generate %size { + ^bb0(%arg0: index): + %1 = dim %tensor, %arg0 : tensor<*xf32> + tensor.yield %1 : index + } : tensor + %1 = tensor.extract %0[%idx] : tensor + // CHECK-NEXT: return %[[RES]] + return %1 : index +} + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.generate_2d +// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: 
tensor<*xf32> +func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { + %size = rank %tensor : tensor<*xf32> + // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]] + // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]] + // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]] + %0 = tensor.generate %size, %size { + ^bb0(%arg0: index, %arg1: index): + %1 = dim %tensor, %arg0 : tensor<*xf32> + %2 = dim %tensor, %arg1 : tensor<*xf32> + %3 = addi %1, %2 : index + tensor.yield %3 : index + } : tensor + %4 = tensor.extract %0[%idx0, %idx1] : tensor + // CHECK-NEXT: return %[[RES]] + return %4 : index +} + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.generate_sideeffects +// CHECK-SAME: %[[IDX:.*]]: index +func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index { + %size = rank %tensor : tensor<*xf32> + %mem = alloc(%size) : memref + // CHECK: %[[DTENSOR:.*]] = tensor.generate + %0 = tensor.generate %size { + ^bb0(%arg0: index): + %1 = dim %tensor, %arg0 : tensor<*xf32> + store %1, %mem[%arg0] : memref + tensor.yield %1 : index + } : tensor + // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]] + %1 = tensor.extract %0[%idx] : tensor + // CHECK-NEXT: return %[[RES]] + return %1 : index +} + +// ----- + +// CHECK-LABEL: @static_tensor.generate +// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index) +func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> { + %c5 = constant 5 : index + // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]] + %0 = tensor.generate %size1, %c5, %size4 { + ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index): + %1 = constant 32 : index + tensor.yield %1 : index + // CHECK: : tensor<3x?x5x7x?xindex> + } : tensor<3x?x?x7x?xindex> + // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> + return %0 : tensor<3x?x?x7x?xindex> +} diff --git 
a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -13,3 +13,87 @@ %0 = tensor.extract %arg0[] : tensor return } + +// ----- + +func @tensor.from_elements_wrong_result_type() { + // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} + %c0 = constant 0 : i32 + %0 = tensor.from_elements %c0 : tensor<*xi32> + return +} + +// ----- + +func @tensor.from_elements_wrong_elements_count() { + // expected-error@+2 {{1 operands present, but expected 2}} + %c0 = constant 0 : index + %0 = tensor.from_elements %c0 : tensor<2xindex> + return +} + +// ----- + +func @tensor.generate(%m : index) + -> tensor { + // expected-error @+1 {{must have as many index operands as dynamic extents in the result type}} + %tnsr = tensor.generate %m { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor + return %tnsr : tensor +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor { + // expected-error @+1 {{must have one body argument per input dimension}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor + return %tnsr : tensor +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor { + // expected-error @+1 {{all body arguments must be index}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : i64): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor + return %tnsr : tensor +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor { + // expected-error @+2 {{op expects regions to end with 'tensor.yield', found 'std.return'}} + // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tensor.yield'}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : 
index): + %elem = constant 8.0 : f32 + return %elem : f32 + } : tensor + return %tnsr : tensor +} + +// ----- + +func @tensor.generate(%m : index, %n : index) + -> tensor { + // expected-error @+1 {{body must be terminated with a `yield` operation of the tensor element type}} + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8 : i32 + tensor.yield %elem : i32 + } : tensor + return %tnsr : tensor +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -21,3 +21,35 @@ %0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor return } + +// CHECK-LABEL: func @tensor.from_elements() { +func @tensor.from_elements() { + %c0 = "std.constant"() {value = 0: index} : () -> index + // CHECK: %0 = tensor.from_elements %c0 : tensor<1xindex> + %0 = tensor.from_elements %c0 : tensor<1xindex> + + %c1 = "std.constant"() {value = 1: index} : () -> index + // CHECK: %1 = tensor.from_elements %c0, %c1 : tensor<2xindex> + %1 = tensor.from_elements %c0, %c1 : tensor<2xindex> + + %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32 + // CHECK: [[C0_F32:%.*]] = constant + // CHECK: %2 = tensor.from_elements [[C0_F32]] : tensor<1xf32> + %2 = tensor.from_elements %c0_f32 : tensor<1xf32> + + // CHECK: tensor.from_elements : tensor<0xindex> + %3 = tensor.from_elements : tensor<0xindex> + + return +} + +// CHECK-LABEL: @tensor.generate +func @tensor.generate(%m : index, %n : index) + -> tensor { + %tnsr = tensor.generate %m, %n { + ^bb0(%i : index, %j : index, %k : index): + %elem = constant 8.0 : f32 + tensor.yield %elem : f32 + } : tensor + return %tnsr : tensor +} diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -675,27 +675,6 @@ return } -// CHECK-LABEL: func @tensor_from_elements() { -func @tensor_from_elements() { - %c0 = "std.constant"() 
{value = 0: index} : () -> index - // CHECK: %0 = tensor_from_elements %c0 : tensor<1xindex> - %0 = tensor_from_elements %c0 : tensor<1xindex> - - %c1 = "std.constant"() {value = 1: index} : () -> index - // CHECK: %1 = tensor_from_elements %c0, %c1 : tensor<2xindex> - %1 = tensor_from_elements %c0, %c1 : tensor<2xindex> - - %c0_f32 = "std.constant"() {value = 0.0: f32} : () -> f32 - // CHECK: [[C0_F32:%.*]] = constant - // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32> - %2 = tensor_from_elements %c0_f32 : tensor<1xf32> - - // CHECK: tensor_from_elements : tensor<0xindex> - %3 = tensor_from_elements : tensor<0xindex> - - return -} - // CHECK-LABEL: func @memref_cast(%arg0 func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) { // CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -541,24 +541,6 @@ // ----- -func @tensor_from_elements_wrong_result_type() { - // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} - %c0 = constant 0 : i32 - %0 = tensor_from_elements %c0 : tensor<*xi32> - return -} - -// ----- - -func @tensor_from_elements_wrong_elements_count() { - // expected-error@+2 {{1 operands present, but expected 2}} - %c0 = constant 0 : index - %0 = tensor_from_elements %c0 : tensor<2xindex> - return -} - -// ----- - func @index_cast_index_to_index(%arg0: index) { // expected-error@+1 {{are cast incompatible}} %0 = index_cast %arg0: index to index diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1032,93 +1032,6 @@ // ----- -// CHECK-LABEL: func @extract_from_tensor_from_elements -func @extract_from_tensor_from_elements(%element : index) -> index { - // 
CHECK-SAME: ([[ARG:%.*]]: index) - %c0 = constant 0 : index - %tensor = tensor_from_elements %element : tensor<1xindex> - %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex> - // CHECK: [[ARG]] : index - return %extracted_element : index -} - -// ----- - -// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements -// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> -func @extract_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> - // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]] - %0 = dynamic_tensor_from_elements %size { - ^bb0(%arg0: index): - %1 = dim %tensor, %arg0 : tensor<*xf32> - yield %1 : index - } : tensor - %1 = tensor.extract %0[%idx] : tensor - // CHECK-NEXT: return %[[RES]] - return %1 : index -} - -// ----- - -// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_2d -// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> -func @extract_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> - // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]] - // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]] - // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]] - %0 = dynamic_tensor_from_elements %size, %size { - ^bb0(%arg0: index, %arg1: index): - %1 = dim %tensor, %arg0 : tensor<*xf32> - %2 = dim %tensor, %arg1 : tensor<*xf32> - %3 = addi %1, %2 : index - yield %3 : index - } : tensor - %4 = tensor.extract %0[%idx0, %idx1] : tensor - // CHECK-NEXT: return %[[RES]] - return %4 : index -} - -// ----- - -// CHECK-LABEL: func @extract_from_dynamic_tensor_from_elements_sideeffects -// CHECK-SAME: %[[IDX:.*]]: index -func @extract_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index { - %size = rank %tensor : tensor<*xf32> - %mem = alloc(%size) : memref - // CHECK: %[[DTENSOR:.*]] = 
dynamic_tensor_from_elements - %0 = dynamic_tensor_from_elements %size { - ^bb0(%arg0: index): - %1 = dim %tensor, %arg0 : tensor<*xf32> - store %1, %mem[%arg0] : memref - yield %1 : index - } : tensor - // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]] - %1 = tensor.extract %0[%idx] : tensor - // CHECK-NEXT: return %[[RES]] - return %1 : index -} - -// ----- - -// CHECK-LABEL: @static_dynamic_tensor_from_elements -// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index) -func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> { - %c5 = constant 5 : index - // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]] - %0 = dynamic_tensor_from_elements %size1, %c5, %size4 { - ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index): - %1 = constant 32 : index - yield %1 : index - // CHECK: : tensor<3x?x5x7x?xindex> - } : tensor<3x?x?x7x?xindex> - // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> - return %0 : tensor<3x?x?x7x?xindex> -} - -// ----- - // CHECK-LABEL: func @subtensor // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)