diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -34,7 +34,7 @@ class LinalgDependenceGraph; /// A struct containing the Linalg producer before and after fusion. -/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op +/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op /// before the consumer Linalg op, until enough canonicalizations have applied. struct FusionInfo { LinalgOp originalProducer; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -354,31 +354,6 @@ /// ``` bool canFoldIntoConsumerOp(MemRefCastOp castOp); -/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors. -/// Determines whether TensorCastOp casts to a more dynamic version of the -/// source tensor. This is useful to fold a tensor_cast into a consuming op and -/// implement canonicalization patterns for ops in different dialects that may -/// consume the results of tensor_cast operations. Such foldable tensor_cast -/// operations are typically inserted as `subtensor` ops and are canonicalized, -/// to preserve the type compatibility of their uses. -/// -/// Returns true when all conditions are met: -/// 1. source and result are ranked tensors with same element type and rank. -/// 2. the tensor type has more static information than the result -/// -/// Example: -/// ```mlir -/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor -/// %2 = consumer %1 ... : tensor ... -/// ``` -/// -/// folds into: -/// -/// ```mlir -/// %2 = consumer %0 ... : tensor<8x16xf32> ... -/// ``` -bool canFoldIntoConsumerOp(TensorCastOp castOp); - /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer /// comparison predicates. bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -62,7 +62,7 @@ let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; - let verifier = [{ return ::verifyCastOp(*this); }]; + let verifier = [{ return impl::verifyCastOp(*this, areCastCompatible); }]; let hasFolder = 1; } @@ -3428,56 +3428,6 @@ }]; } -//===----------------------------------------------------------------------===// -// TensorCastOp -//===----------------------------------------------------------------------===// - -def TensorCastOp : CastOp<"tensor_cast"> { - let summary = "tensor cast operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type - ``` - - Convert a tensor from one type to an equivalent type without changing any - data elements. The source and destination types must both be tensor types - with the same element type. If both are ranked, then the rank should be the - same and static dimensions should match. The operation is invalid if - converting to a mismatching constant dimension. - - Example: - - ```mlir - // Convert from unknown rank to rank 2 with unknown dimension sizes. 
- %2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor - %2 = tensor_cast %1 : tensor<*xf32> to tensor - - // Convert to a type with more known dimensions. - %3 = "std.tensor_cast"(%2) : (tensor) -> tensor<4x?xf32> - - // Discard static dimension and rank information. - %4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor - %5 = "std.tensor_cast"(%4) : (tensor) -> tensor<*xf32> - ``` - }]; - - let arguments = (ins AnyTensor:$source); - let results = (outs AnyTensor); - - let extraClassDeclaration = [{ - /// Return true if `a` and `b` are valid operand and result pairs for - /// the operation. - static bool areCastCompatible(Type a, Type b); - - /// The result of a tensor_cast is always a tensor. - TensorType getType() { return getResult().getType().cast(); } - }]; - - let hasCanonicalizer = 1; -} - //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -28,4 +28,38 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc" +//===----------------------------------------------------------------------===// +// Tensor Dialect Helpers +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace tensor { + +/// Determines whether tensor::CastOp casts to a more dynamic version of the +/// source tensor. This is useful to fold a tensor.cast into a consuming op and +/// implement canonicalization patterns for ops in different dialects that may +/// consume the results of tensor.cast operations. Such foldable tensor.cast +/// operations are typically inserted as `subtensor` ops and are canonicalized, +/// to preserve the type compatibility of their uses. +/// +/// Returns true when all conditions are met: +/// 1. source and result are ranked tensors with same element type and rank. +/// 2. the tensor type has more static information than the result +/// +/// Example: +/// ```mlir +/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor +/// %2 = consumer %1 ... : tensor ... +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = consumer %0 ... : tensor<8x16xf32> ... +/// ``` +bool canFoldIntoConsumerOp(CastOp castOp); + +} // namespace tensor +} // namespace mlir + #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_ diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -19,6 +19,52 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +def Tensor_CastOp : Tensor_Op<"cast", [NoSideEffect]> { + let summary = "tensor cast operation"; + let description = [{ + Convert a tensor from one type to an equivalent type without changing any + data elements. The source and destination types must both be tensor types + with the same element type. If both are ranked, then the rank should be the + same and static dimensions should match. The operation is invalid if + converting to a mismatching constant dimension. 
+ + Example: + + ```mlir + // Convert from unknown rank to rank 2 with unknown dimension sizes. + %2 = tensor.cast %1 : tensor<*xf32> to tensor + + // Convert to a type with more known dimensions. + %3 = tensor.cast %2 : tensor to tensor<4x?xf32> + + // Discard static dimension and rank information. + %4 = tensor.cast %3 : tensor<4x?xf32> to tensor + %5 = tensor.cast %4 : tensor to tensor<*xf32> + ``` + }]; + + let arguments = (ins AnyTensor:$source); + let results = (outs AnyTensor:$dest); + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; + let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a tensor.cast is always a tensor. + TensorType getType() { return getResult().getType().cast(); } + }]; + + let hasFolder = 1; + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1775,11 +1775,18 @@ // These functions are out-of-line implementations of the methods in CastOp, // which avoids them being template instantiated/duplicated. namespace impl { +// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the +// need for them, but some older ODS code in `std` still depends on them). void buildCastOp(OpBuilder &builder, OperationState &result, Value source, Type destType); ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); void printCastOp(Operation *op, OpAsmPrinter &p); +// TODO: Create a CastOpInterface with a method areCastCompatible. +// Also, consider adding functionality to CastOpInterface to be able to perform +// the ChainedTensorCast canonicalization generically. 
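The two declarations that follow (`foldCastOp` and the new `verifyCastOp`) are the helpers that `tensor.cast` now shares with the `std` cast ops. Below is a hedged sketch, not the verbatim upstream code, of what they do: the fold helper only removes identity casts (chains of casts are left to the op-specific `ChainedTensorCast` pattern named in the TODO), and the verifier is parameterized on each op's `areCastCompatible` predicate. The callback type is assumed here to be `function_ref<bool(Type, Type)>`.

```cpp
// Hedged sketch of the shared cast helpers in the `impl` namespace.
#include "mlir/IR/Operation.h"

namespace mlir {
namespace impl {

/// Only the trivial case folds here: a cast whose operand and result types
/// already match is a no-op and forwards its source. Collapsing chains of
/// casts is handled by op-level canonicalization (e.g. ChainedTensorCast).
Value foldCastOp(Operation *op) {
  if (op->getOperand(0).getType() == op->getResult(0).getType())
    return op->getOperand(0);
  return Value();
}

/// Shared verifier: each cast op passes its own static compatibility
/// predicate, so std cast ops and tensor.cast reuse one implementation.
LogicalResult verifyCastOp(Operation *op,
                           function_ref<bool(Type, Type)> areCastCompatible) {
  Type opType = op->getOperand(0).getType();
  Type resType = op->getResult(0).getType();
  if (!areCastCompatible(opType, resType))
    return op->emitError("operand type ")
           << opType << " and result type " << resType
           << " are cast incompatible";
  return success();
}

} // namespace impl
} // namespace mlir
```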
Value foldCastOp(Operation *op); +LogicalResult verifyCastOp(Operation *op, + function_ref areCastCompatible); } // namespace impl } // end namespace mlir diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -tensor-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s @@ -8,7 +8,7 @@ %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32> %addf = addf %a, %b : tensor<3xf32> - %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32> + %addf_unranked = tensor.cast %addf : tensor<3xf32> to tensor<*xf32> call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> () // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data = // CHECK-NEXT: [11, 22, 33] diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \ +// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \ +// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \ +// RUN: -finalizing-bufferize \ // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ @@ -15,14 +17,14 @@ %inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32> %inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32> - %unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32> + %unranked_at_position_0 = tensor.cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32> call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data = // CHECK-NEXT: [20, 10] - %unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32> + %unranked_at_position_1 = tensor.cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32> call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir +++ 
b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \ +// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \ +// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \ +// RUN: -finalizing-bufferize \ // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ @@ -9,7 +11,7 @@ %insert_val = constant dense<20.0> : tensor<1xf32> %inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32> - %unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32> + %unranked = tensor.cast %inserted : tensor<2xf32> to tensor<*xf32> call @print_memref_f32(%unranked) : (tensor<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize \ -// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \ +// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \ // RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ @@ -19,7 +19,7 @@ // Note that this is skipping a step and we would need at least some function // attribute to declare that this conversion is valid (e.g. when we statically // know that things will play nicely at the C ABI boundary). 
- %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32> + %unranked = tensor.cast %0 : tensor<4xf32> to tensor<*xf32> call @print_memref_f32(%unranked) : (tensor<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir @@ -1,12 +1,13 @@ // RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize \ -// RUN: -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \ +// RUN: -tensor-bufferize -func-bufferize -finalizing-bufferize -convert-linalg-to-loops \ // RUN: -convert-linalg-to-llvm -convert-std-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \ -// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \ +// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -tensor-bufferize \ +// RUN: -func-bufferize \ // RUN: -finalizing-bufferize -convert-linalg-to-loops -convert-scf-to-std \ // RUN: -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ @@ -23,7 +24,7 @@ %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>) init(%C: tensor<2x4xf32>) -> tensor<2x4xf32> - %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32> + %unranked = tensor.cast %D : tensor<2x4xf32> to tensor<*xf32> call @print_memref_f32(%unranked) : (tensor<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -103,9 +103,9 @@ auto erasedRankType = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); Value rankErasedLhs = - rewriter.create(loc, erasedRankType, transformed.lhs()); + rewriter.create(loc, erasedRankType, transformed.lhs()); Value rankErasedRhs = - rewriter.create(loc, erasedRankType, transformed.rhs()); + rewriter.create(loc, erasedRankType, transformed.rhs()); Value lesserRankOperand = rewriter.create(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); Value greaterRankOperand = @@ -186,7 +186,7 @@ Value tensor = rewriter.create(loc, indexTy, extentOperands); Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); - rewriter.replaceOpWithNewOp(op, tensor, resultTy); + rewriter.replaceOpWithNewOp(op, resultTy, tensor); return success(); } @@ -246,9 +246,9 @@ auto erasedRankType = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); Value rankErasedLhs = - rewriter.create(loc, erasedRankType, transformed.lhs()); + rewriter.create(loc, erasedRankType, transformed.lhs()); Value rankErasedRhs = - rewriter.create(loc, erasedRankType, transformed.rhs()); + rewriter.create(loc, erasedRankType, transformed.rhs()); Value lesserRankOperand = rewriter.create(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); Value greaterRankOperand = @@ -528,8 +528,8 @@ // Materialize extent tensor. 
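One consequence visible in the ShapeToStandard hunks above: besides the op class changing from `TensorCastOp` to `tensor::CastOp`, several `replaceOpWithNewOp` call sites swap their trailing arguments. The before/after fragment below spells this out; the template arguments and the exact builder signatures are assumptions here, inferred from `impl::buildCastOp` and the default ODS-generated builder.

```cpp
// Before: std.tensor_cast also exposed impl::buildCastOp's
//   build(builder, state, /*source=*/Value, /*destType=*/Type)
// so this compiled with the value first:
//   rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
//                                             op.getType());
//
// After: tensor.cast only has the default ODS-generated builder
//   build(builder, state, /*destType=*/Type, /*source=*/Value)
// so the result type now precedes the source value:
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                            staticExtentTensor);
```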
Value staticExtentTensor = rewriter.create( loc, rewriter.getIndexType(), extentValues); - rewriter.replaceOpWithNewOp(op, staticExtentTensor, - op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), + staticExtentTensor); return success(); } @@ -561,8 +561,8 @@ if (!adaptor.input().getType().isa()) return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); - rewriter.replaceOpWithNewOp(op, adaptor.input(), - op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), + adaptor.input()); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1676,12 +1676,12 @@ if (!linalgOp) return failure(); - // If no operand comes from a TensorCastOp and can be folded then fail. + // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) { if (v.isa()) return false; - auto castOp = v.getDefiningOp(); + auto castOp = v.getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); }); if (!hasTensorCastOperand) @@ -1693,7 +1693,7 @@ newOperands.reserve(op->getNumOperands()); // Inputs may fold. for (Value v : linalgOp.getInputs()) { - auto tensorCastOp = v.getDefiningOp(); + auto tensorCastOp = v.getDefiningOp(); newOperands.push_back( canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v); } @@ -1702,7 +1702,7 @@ linalgOp.getOutputBuffers().end()); // Init tensors may fold, in which case the resultType must also change. for (Value v : linalgOp.getInitTensors()) { - auto tensorCastOp = v.getDefiningOp(); + auto tensorCastOp = v.getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); newOperands.push_back(fold ? tensorCastOp.getOperand() : v); newResultTypes.push_back(newOperands.back().getType()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -36,6 +36,7 @@ MLIRStandard MLIRStandardOpsTransforms MLIRStandardToLLVM + MLIRTensor MLIRTransforms MLIRTransformUtils MLIRVector diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Dominance.h" @@ -517,13 +518,13 @@ // Replace use. // Canonicalizations are not guaranteed to have happened before constructing // `fusedProducer`. In the tensor case this can result in temporary type - // mismatches. Insert a `tensor_cast` op to propagate the transformation + // mismatches. Insert a `tensor.cast` op to propagate the transformation // invariant that types are compatible. 
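For reference, the cast insertion this comment describes, with the op class written out (assumed to be `tensor::CastOp`); `b`, `def`, `use`, `consumerType`, and `fusedProducer` are the values used in the surrounding fusion code.

```cpp
// If fusion left the producer result with a type that no longer matches the
// consumer operand (e.g. static vs. dynamic sizes), bridge the gap with a
// tensor.cast; later canonicalization is expected to fold it away.
if (consumerType != def.getType())
  def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
use.set(def);
```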
Value def = fusedProducer->getResult(producerIdx); OpOperand &use = consumer->getOpOperand(consumerIdx); Type consumerType = use.get().getType(); if (consumerType != def.getType()) - def = b.create(fusedProducer.getLoc(), consumerType, def); + def = b.create(fusedProducer.getLoc(), consumerType, def); use.set(def); return FusionInfo{producerOp, fusedProducer}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -569,7 +570,7 @@ ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx); SubTensorOp::getCanonicalizationPatterns(patterns, ctx); SubViewOp::getCanonicalizationPatterns(patterns, ctx); - TensorCastOp::getCanonicalizationPatterns(patterns, ctx); + tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); ViewOp::getCanonicalizationPatterns(patterns, ctx); CanonicalizationPatternList< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -18,4 +18,5 @@ MLIRIR MLIRSideEffectInterfaces MLIRStandard + MLIRTensor ) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -1,5 +1,6 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" def AllInputShapesEq : Constraint; -// Fold tensor_cast(const_shape) to const_shape. This changes the type of +// Fold tensor.cast(const_shape) to const_shape. This changes the type of // const_shape to the destination type of the cast. def TensorCastConstShape : Pat < - (TensorCastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>; + (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -141,18 +141,6 @@ << op->getResult(0).getType(); } -/// A custom cast operation verifier. 
-template -static LogicalResult verifyCastOp(T op) { - auto opType = op.getOperand().getType(); - auto resType = op.getType(); - if (!T::areCastCompatible(opType, resType)) - return op.emitError("operand type ") << opType << " and result type " - << resType << " are cast incompatible"; - - return success(); -} - void StandardOpsDialect::initialize() { getContext()->loadDialect(); addOperations(tensorFromElements, resultType, - newOp); + rewriter.replaceOpWithNewOp(tensorFromElements, resultType, + newOp); return success(); } }; @@ -1895,7 +1883,7 @@ /// Canonicalizes the pattern of the form /// -/// %val = tensor_cast %source : : tensor to tensor<2xi32> +/// %val = tensor.cast %source : : tensor to tensor<2xi32> /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> /// /// to @@ -1906,7 +1894,7 @@ LogicalResult matchAndRewrite(tensor::ExtractOp extract, PatternRewriter &rewriter) const final { - auto tensorCast = extract.tensor().getDefiningOp(); + auto tensorCast = extract.tensor().getDefiningOp(); if (!tensorCast) return failure(); @@ -3377,7 +3365,7 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op, SubTensorOp newOp) { - rewriter.replaceOpWithNewOp(op, newOp, op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), newOp); } /// Pattern to rewrite a subview op with constant arguments. @@ -3518,60 +3506,6 @@ return true; } -/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors. -/// Determines whether TensorCastOp casts to a more dynamic version of the -/// source tensor. This is useful to fold a tensor_cast into a consuming op and -/// implement canonicalization patterns for ops in different dialects that may -/// consume the results of tensor_cast operations. Such foldable tensor_cast -/// operations are typically inserted as `subtensor` ops and are canonicalized, -/// to preserve the type compatibility of their uses. -/// -/// Returns true when all conditions are met: -/// 1. source and result are ranked tensors with same element type and rank. -/// 2. the tensor type has more static information than the result -/// -/// Example: -/// ```mlir -/// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor -/// %2 = consumer %1 ... : tensor ... -/// ``` -/// -/// folds into: -/// -/// ```mlir -/// %2 = consumer %0 ... : tensor<8x16xf32> ... -/// ``` -bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) { - if (!castOp) - return false; - - RankedTensorType sourceType = - castOp.source().getType().dyn_cast(); - RankedTensorType resultType = castOp.getType().dyn_cast(); - - // Requires RankedTensorType. - if (!sourceType || !resultType) - return false; - - // Requires same elemental type. - if (sourceType.getElementType() != resultType.getElementType()) - return false; - - // Requires same rank. - if (sourceType.getRank() != resultType.getRank()) - return false; - - // If cast is towards more static sizes along any dimension, don't fold. - for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { - auto ss = std::get<0>(it), st = std::get<1>(it); - if (ss != st) - if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) - return false; - } - - return true; -} - namespace { /// Pattern to rewrite a subview op with MemRefCast arguments. 
/// This essentially pushes memref_cast past its consuming subview when @@ -3839,107 +3773,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// TensorCastOp -//===----------------------------------------------------------------------===// - -bool TensorCastOp::areCastCompatible(Type a, Type b) { - auto aT = a.dyn_cast(); - auto bT = b.dyn_cast(); - if (!aT || !bT) - return false; - - if (aT.getElementType() != bT.getElementType()) - return false; - - return succeeded(verifyCompatibleShape(aT, bT)); -} - -OpFoldResult TensorCastOp::fold(ArrayRef operands) { - return impl::foldCastOp(*this); -} - -/// Compute a TensorType that has the joined shape knowledge of the two -/// given TensorTypes. The element types need to match. -static TensorType joinShapes(TensorType one, TensorType two) { - assert(one.getElementType() == two.getElementType()); - - if (!one.hasRank()) - return two; - if (!two.hasRank()) - return one; - - int64_t rank = one.getRank(); - if (rank != two.getRank()) - return {}; - - SmallVector join; - join.reserve(rank); - for (int64_t i = 0; i < rank; ++i) { - if (one.isDynamicDim(i)) { - join.push_back(two.getDimSize(i)); - continue; - } - if (two.isDynamicDim(i)) { - join.push_back(one.getDimSize(i)); - continue; - } - if (one.getDimSize(i) != two.getDimSize(i)) - return {}; - join.push_back(one.getDimSize(i)); - } - return RankedTensorType::get(join, one.getElementType()); -} - -namespace { - -/// Replaces chains of two tensor_cast operations by a single tensor_cast -/// operation if doing so does not remove runtime constraints. -struct ChainedTensorCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TensorCastOp tensorCast, - PatternRewriter &rewriter) const final { - auto tensorCastOperand = - tensorCast.getOperand().getDefiningOp(); - - if (!tensorCastOperand) - return failure(); - - auto sourceType = - tensorCastOperand.getOperand().getType().cast(); - auto intermediateType = tensorCastOperand.getType().cast(); - auto resultType = tensorCast.getType().cast(); - - // We can remove the intermediate cast if joining all three produces the - // same result as just joining the source and result shapes. - auto firstJoin = - joinShapes(joinShapes(sourceType, intermediateType), resultType); - - // The join might not exist if the cast sequence would fail at runtime. - if (!firstJoin) - return failure(); - - // The newJoin always exists if the above join exists, it might just contain - // less information. If so, we cannot drop the intermediate cast, as doing - // so would remove runtime checks. 
- auto newJoin = joinShapes(sourceType, resultType); - if (firstJoin != newJoin) - return failure(); - - rewriter.replaceOpWithNewOp(tensorCast, resultType, - tensorCastOperand.getOperand()); - return success(); - } -}; - -} // namespace - -void TensorCastOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -117,20 +117,6 @@ }; } // namespace -namespace { -class BufferizeTensorCastOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(TensorCastOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, operands[0]); - return success(); - } -}; -} // namespace - namespace { class BufferizeTensorFromElementsOp : public OpConversionPattern { @@ -162,7 +148,6 @@ BufferizeDimOp, BufferizeDynamicTensorFromElementsOp, BufferizeSelectOp, - BufferizeTensorCastOp, BufferizeTensorFromElementsOp // clang-format on >(typeConverter, context); @@ -180,8 +165,7 @@ target.addLegalDialect(); populateStdBufferizePatterns(context, typeConverter, patterns); - target.addIllegalOp(); + target.addIllegalOp(); // We only bufferize the case of tensor selected type and scalar condition, // as that boils down to a select over memref descriptors (don't need to // touch the data). diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -8,12 +8,165 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::tensor; +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +/// Determines whether tensor::CastOp casts to a more dynamic version of the +/// source tensor. This is useful to fold a tensor.cast into a consuming op and +/// implement canonicalization patterns for ops in different dialects that may +/// consume the results of tensor.cast operations. Such foldable tensor.cast +/// operations are typically inserted as `subtensor` ops and are canonicalized, +/// to preserve the type compatibility of their uses. +/// +/// Returns true when all conditions are met: +/// 1. source and result are ranked tensors with same element type and rank. +/// 2. the tensor type has more static information than the result +/// +/// Example: +/// ```mlir +/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor +/// %2 = consumer %1 ... : tensor ... +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = consumer %0 ... : tensor<8x16xf32> ... 
+/// ``` +bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { + if (!castOp) + return false; + + RankedTensorType sourceType = + castOp.source().getType().dyn_cast(); + RankedTensorType resultType = castOp.getType().dyn_cast(); + + // Requires RankedTensorType. + if (!sourceType || !resultType) + return false; + + // Requires same elemental type. + if (sourceType.getElementType() != resultType.getElementType()) + return false; + + // Requires same rank. + if (sourceType.getRank() != resultType.getRank()) + return false; + + // If cast is towards more static sizes along any dimension, don't fold. + for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) { + if (ShapedType::isDynamic(std::get<0>(t)) && + !ShapedType::isDynamic(std::get<1>(t))) + return false; + } + + return true; +} + +bool CastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast(); + auto bT = b.dyn_cast(); + if (!aT || !bT) + return false; + + if (aT.getElementType() != bT.getElementType()) + return false; + + return succeeded(verifyCompatibleShape(aT, bT)); +} + +OpFoldResult CastOp::fold(ArrayRef operands) { + return impl::foldCastOp(*this); +} + +/// Compute a TensorType that has the joined shape knowledge of the two +/// given TensorTypes. The element types need to match. +static TensorType joinShapes(TensorType one, TensorType two) { + assert(one.getElementType() == two.getElementType()); + + if (!one.hasRank()) + return two; + if (!two.hasRank()) + return one; + + int64_t rank = one.getRank(); + if (rank != two.getRank()) + return {}; + + SmallVector join; + join.reserve(rank); + for (int64_t i = 0; i < rank; ++i) { + if (one.isDynamicDim(i)) { + join.push_back(two.getDimSize(i)); + continue; + } + if (two.isDynamicDim(i)) { + join.push_back(one.getDimSize(i)); + continue; + } + if (one.getDimSize(i) != two.getDimSize(i)) + return {}; + join.push_back(one.getDimSize(i)); + } + return RankedTensorType::get(join, one.getElementType()); +} + +namespace { + +/// Replaces chains of two tensor.cast operations by a single tensor.cast +/// operation if doing so does not remove runtime constraints. +struct ChainedTensorCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CastOp tensorCast, + PatternRewriter &rewriter) const final { + auto tensorCastOperand = tensorCast.getOperand().getDefiningOp(); + + if (!tensorCastOperand) + return failure(); + + auto sourceType = + tensorCastOperand.getOperand().getType().cast(); + auto intermediateType = tensorCastOperand.getType().cast(); + auto resultType = tensorCast.getType().cast(); + + // We can remove the intermediate cast if joining all three produces the + // same result as just joining the source and result shapes. + auto firstJoin = + joinShapes(joinShapes(sourceType, intermediateType), resultType); + + // The join might not exist if the cast sequence would fail at runtime. + if (!firstJoin) + return failure(); + + // The newJoin always exists if the above join exists, it might just contain + // less information. If so, we cannot drop the intermediate cast, as doing + // so would remove runtime checks. 
+ auto newJoin = joinShapes(sourceType, resultType); + if (firstJoin != newJoin) + return failure(); + + rewriter.replaceOpWithNewOp(tensorCast, resultType, + tensorCastOperand.getOperand()); + return success(); + } +}; + +} // namespace + +void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -19,6 +19,20 @@ using namespace mlir; +namespace { +class BufferizeCastOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::CastOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto resultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resultType, operands[0]); + return success(); + } +}; +} // namespace + namespace { class BufferizeExtractOp : public OpConversionPattern { public: @@ -37,7 +51,7 @@ void mlir::populateTensorBufferizePatterns( MLIRContext *context, BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - patterns.insert(typeConverter, context); + patterns.insert(typeConverter, context); } namespace { @@ -49,7 +63,7 @@ ConversionTarget target(*context); populateTensorBufferizePatterns(context, typeConverter, patterns); - target.addIllegalOp(); + target.addIllegalOp(); target.addLegalDialect(); if (failed( diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1213,6 +1213,19 @@ return nullptr; } +LogicalResult +impl::verifyCastOp(Operation *op, + function_ref areCastCompatible) { + auto opType = op->getOperand(0).getType(); + auto resType = op->getResult(0).getType(); + if (!areCastCompatible(opType, resType)) + return op->emitError("operand type ") + << opType << " and result type " << resType + << " are cast incompatible"; + + return success(); +} + //===----------------------------------------------------------------------===// // Misc. 
utils //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -95,7 +95,7 @@ // CHECK: %[[C2:.*]] = constant 2 : index // CHECK: %[[C3:.*]] = constant 3 : index // CHECK: %[[TENSOR3:.*]] = tensor_from_elements %[[C1]], %[[C2]], %[[C3]] - // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor + // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR3]] : tensor<3xindex> to tensor // CHECK: return %[[RESULT]] : tensor %shape = shape.const_shape [1, 2, 3] : tensor return %shape : tensor @@ -108,7 +108,7 @@ // CHECK-SAME: () -> tensor func @const_shape_zero_elements() -> tensor { // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex> - // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor + // CHECK: %[[RESULT:.*]] = tensor.cast %[[TENSOR]] : tensor<0xindex> to tensor // CHECK: return %[[RESULT]] : tensor %shape = shape.const_shape [] : tensor return %shape : tensor @@ -152,13 +152,13 @@ // ----- -// Lower `to_extent_tensor` to `std.tensor_cast` +// Lower `to_extent_tensor` to `tensor.cast` // Fold to_extent_tensor when already on tensor. // CHECK-LABEL: @to_extent_tensor // CHECK-SAME: (%[[ARG:.*]]: tensor func @to_extent_tensor(%arg: tensor) -> tensor<3xindex> { // CHECK-NOT: to_extent_tensor - // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor to tensor<3xindex + // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor to tensor<3xindex %casted = shape.to_extent_tensor %arg : tensor -> tensor<3xindex> // CHECK: return %[[RES]] return %casted : tensor<3xindex> @@ -316,8 +316,8 @@ // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index - // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor to tensor - // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor to tensor + // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor to tensor + // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor to tensor // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index @@ -356,8 +356,8 @@ // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index - // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor - // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor + // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor + // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor // 
CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index @@ -400,8 +400,8 @@ // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index -// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor -// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor to tensor +// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor +// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor to tensor // CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor // CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor // CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index @@ -438,8 +438,8 @@ // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index -// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor to tensor -// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor to tensor +// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor to tensor +// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor to tensor // CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor // CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor // CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -317,20 +317,20 @@ // ----- -// CHECK-LABEL: func @tensor_cast( -func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>) +// CHECK-LABEL: func @tensor.cast( +func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>) -> tensor<3x?xf32> { - %ta = tensor_cast %a : tensor<3x4xf32> to tensor - %tb = tensor_cast %b : tensor<4x?xf32> to tensor - %tc = tensor_cast %c : tensor<3x?xf32> to tensor + %ta = tensor.cast %a : tensor<3x4xf32> to tensor + %tb = tensor.cast %b : tensor<4x?xf32> to tensor + %tc = tensor.cast %c : tensor<3x?xf32> to tensor // CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>) // CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32> %0 = linalg.matmul ins(%ta, %tb: tensor, tensor) init(%tc: tensor) -> tensor - %1 = tensor_cast %0 : tensor to tensor<3x?xf32> + %1 = tensor.cast %0 : tensor to tensor<3x?xf32> return %1: tensor<3x?xf32> } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -872,24 +872,24 @@ // ----- -// Verify that tensor_cast folding uses the correct type -// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned -func @fold_tensor_cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> { +// Verify that tensor.cast folding uses the correct type +// CHECK-LABEL: 
@fold_tensor.cast_of_const_shape_returned +func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> { // CHECK: constant dense<2> : tensor<1xindex> - // CHECK-NOT: tensor_cast + // CHECK-NOT: tensor.cast %0 = shape.const_shape [2] : tensor - %1 = tensor_cast %0 : tensor to tensor<1xindex> + %1 = tensor.cast %0 : tensor to tensor<1xindex> return %1 : tensor<1xindex> } // ----- -// Verify that tensor_cast folding uses the correct type -// CHECK-LABEL: @fold_tensor_cast_of_const_shape_returned_dynamic -func @fold_tensor_cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor { +// Verify that tensor.cast folding uses the correct type +// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic +func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor { // CHECK: shape.const_shape [2] : tensor - // CHECK-NOT: tensor_cast + // CHECK-NOT: tensor.cast %0 = shape.const_shape [2] : tensor<1xindex> - %1 = tensor_cast %0 : tensor<1xindex> to tensor + %1 = tensor.cast %0 : tensor<1xindex> to tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -75,39 +75,6 @@ return %0 : tensor } -// CHECK-LABEL: func @tensor_cast( -// CHECK-SAME: %[[TENSOR:.*]]: tensor) -> tensor<2xindex> { -// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] -// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref to memref<2xindex> -// CHECK: %[[RET:.*]] = tensor_load %[[CASTED]] -// CHECK: return %[[RET]] : tensor<2xindex> -func @tensor_cast(%arg0: tensor) -> tensor<2xindex> { - %0 = tensor_cast %arg0 : tensor to tensor<2xindex> - return %0 : tensor<2xindex> -} - -// CHECK-LABEL: func @tensor_cast_from_unranked( -// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> { -// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32> -// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32> -// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32> -// CHECK: return %[[RET]] : tensor<2xf32> -func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> { - %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32> - return %0 : tensor<2xf32> -} - -// CHECK-LABEL: func @tensor_cast_to_unranked( -// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> { -// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32> -// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32> -// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32> -// CHECK: return %[[RET]] : tensor<*xf32> -func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> { - %0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32> - return %0 : tensor<*xf32> -} - // CHECK-LABEL: func @tensor_from_elements( // CHECK-SAME: %[[ELEM0:.*]]: index, // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -1,5 +1,38 @@ // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s +// CHECK-LABEL: func @tensor.cast( +// CHECK-SAME: %[[TENSOR:.*]]: tensor) -> tensor<2xindex> { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] +// CHECK: %[[CASTED:.*]] = memref_cast %[[MEMREF]] : memref to memref<2xindex> +// CHECK: %[[RET:.*]] = 
tensor_load %[[CASTED]] +// CHECK: return %[[RET]] : tensor<2xindex> +func @tensor.cast(%arg0: tensor) -> tensor<2xindex> { + %0 = tensor.cast %arg0 : tensor to tensor<2xindex> + return %0 : tensor<2xindex> +} + +// CHECK-LABEL: func @tensor.cast_from_unranked( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32> +// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32> +// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32> +// CHECK: return %[[RET]] : tensor<2xf32> +func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> { + %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: func @tensor.cast_to_unranked( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32> +// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32> +// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32> +// CHECK: return %[[RET]] : tensor<*xf32> +func @tensor.cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> { + %0 = tensor.cast %arg0 : tensor<2xf32> to tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[IDX:.*]]: index) -> f32 { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1,4 +1,66 @@ -// RUN: mlir-opt %s -canonicalize | FileCheck %s +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// Checks that NOP casts are removed. 
+// CHECK-LABEL: cast_values +func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> { + // NOP cast + %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32> + // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32> + %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32> + // NOP cast + %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32> + // CHECK-NEXT: return %[[RET]] : tensor<2xi32> + return %4 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @tensor.cast_chain_ok +// CHECK-SAME: %[[IN:.*]]: tensor<*xi32> +func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> { + // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32> + %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32> + %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32> + // CHECK-NEXT: return %[[RES]] + return %1 : tensor<4x8xi32> +} + +// ----- + +// CHECK-LABEL: @tensor.cast_chain_regain +// CHECK-SAME: %[[IN:.*]]: tensor<4xi32> +func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> { + %0 = tensor.cast %input : tensor<4xi32> to tensor + %1 = tensor.cast %0 : tensor to tensor<4xi32> + // CHECK-NEXT: return %[[IN]] + return %1 : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: @tensor.cast_chain_keep +// CHECK-SAME: %[[IN:.*]]: tensor +func @tensor.cast_chain_keep(%input: tensor) -> tensor { + // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]] + %0 = tensor.cast %input : tensor to tensor<4x?xi32> + // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]] + %1 = tensor.cast %0 : tensor<4x?xi32> to tensor + // CHECK-NEXT: return %[[C2]] + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @tensor.cast_chain_invalid +// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32> +func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> { + // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]] + %0 = tensor.cast %input : tensor<4x8xi32> to tensor + // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]] + %1 = tensor.cast %0 : tensor to tensor<8x4xi32> + // CHECK-NEXT: return %[[C2]] + return %1 : tensor<8x4xi32> +} // ----- @@ -31,3 +93,17 @@ // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]] return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32 } + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.cast +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32> +func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 { + // CHECK-NEXT: %[[C0:.*]] = constant 0 : index + %c0 = constant 0 : index + // CHECK-NOT: tensor.cast + %casted = tensor.cast %tensor : tensor<*xf32> to tensor + // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]] + %result = tensor.extract %casted[%c0] : tensor + return %result : f32 +} diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -1,4 +1,10 @@ -// RUN: mlir-opt <%s -verify-diagnostics +// RUN: mlir-opt <%s -split-input-file -verify-diagnostics + +func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) { + // expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}} + %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32> + return +} // ----- diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -1,5 +1,18 @@ // RUN: mlir-opt <%s | mlir-opt | FileCheck %s +// CHECK-LABEL: func @cast( +func @cast(%arg0: tensor<*xf32>, %arg1 : 
tensor<4x4xf32>, %arg2: tensor) { + // CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor + %0 = tensor.cast %arg0 : tensor<*xf32> to tensor + // CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32> + %1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32> + // CHECK: tensor.cast %arg2 : tensor to tensor<4x?xf32> + %2 = tensor.cast %arg2 : tensor to tensor<4x?xf32> + // CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor + %3 = tensor.cast %2 : tensor<4x?xf32> to tensor + return +} + // CHECK-LABEL: func @extract( // CHECK-SAME: %[[TENSOR:.*]]: tensor, // CHECK-SAME: %[[INDEX:.*]]: index) { diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -696,23 +696,6 @@ return } -// CHECK-LABEL: func @tensor_cast(%arg0 -func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor) { - // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor - %0 = tensor_cast %arg0 : tensor<*xf32> to tensor - - // CHECK: %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32> - %1 = tensor_cast %arg1 : tensor<4x4xf32> to tensor<*xf32> - - // CHECK: %2 = tensor_cast %arg2 : tensor to tensor<4x?xf32> - %2 = tensor_cast %arg2 : tensor to tensor<4x?xf32> - - // CHECK: %3 = tensor_cast %2 : tensor<4x?xf32> to tensor - %3 = tensor_cast %2 : tensor<4x?xf32> to tensor - - return -} - // CHECK-LABEL: func @memref_cast(%arg0 func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) { // CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -661,23 +661,15 @@ // Checks that NOP casts are removed. 
// CHECK-LABEL: cast_values -func @cast_values(%arg0: tensor<*xi32>, %arg1: memref) -> (tensor<2xi32>, memref<2xi32>) { - - // NOP casts - %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<*xi32> - %1 = memref_cast %arg1 : memref to memref - - // CHECK-NEXT: %0 = tensor_cast %arg0 : tensor<*xi32> to tensor<2xi32> - // CHECK-NEXT: %1 = memref_cast %arg1 : memref to memref<2xi32> - %2 = tensor_cast %0 : tensor<*xi32> to tensor<2xi32> +func @cast_values(%arg0: memref) -> memref<2xi32> { + // NOP cast + %1 = memref_cast %arg0 : memref to memref + // CHECK-NEXT: %[[RET:.*]] = memref_cast %arg0 : memref to memref<2xi32> %3 = memref_cast %1 : memref to memref<2xi32> - - // NOP casts - %4 = tensor_cast %2 : tensor<2xi32> to tensor<2xi32> + // NOP cast %5 = memref_cast %3 : memref<2xi32> to memref<2xi32> - - // CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32> - return %4, %5 : tensor<2xi32>, memref<2xi32> + // CHECK-NEXT: return %[[RET]] : memref<2xi32> + return %5 : memref<2xi32> } // ----- @@ -1121,61 +1113,12 @@ yield %1 : index // CHECK: : tensor<3x?x5x7x?xindex> } : tensor<3x?x?x7x?xindex> - // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> + // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> return %0 : tensor<3x?x?x7x?xindex> } // ----- -// CHECK-LABEL: @tensor_cast_chain_ok -// CHECK-SAME: %[[IN:.*]]: tensor<*xi32> -func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> { - // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32> - %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32> - %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32> - // CHECK-NEXT: return %[[RES]] - return %1 : tensor<4x8xi32> -} - -// ----- - -// CHECK-LABEL: @tensor_cast_chain_regain -// CHECK-SAME: %[[IN:.*]]: tensor<4xi32> -func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> { - %0 = tensor_cast %input : tensor<4xi32> to tensor - %1 = tensor_cast %0 : tensor to tensor<4xi32> - // CHECK-NEXT: return %[[IN]] - return %1 : tensor<4xi32> -} - -// ----- - -// CHECK-LABEL: @tensor_cast_chain_keep -// CHECK-SAME: %[[IN:.*]]: tensor -func @tensor_cast_chain_keep(%input: tensor) -> tensor { - // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]] - %0 = tensor_cast %input : tensor to tensor<4x?xi32> - // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]] - %1 = tensor_cast %0 : tensor<4x?xi32> to tensor - // CHECK-NEXT: return %[[C2]] - return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_cast_chain_invalid -// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32> -func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> { - // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]] - %0 = tensor_cast %input : tensor<4x8xi32> to tensor - // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]] - %1 = tensor_cast %0 : tensor to tensor<8x4xi32> - // CHECK-NEXT: return %[[C2]] - return %1 : tensor<8x4xi32> -} - -// ----- - // CHECK-LABEL: func @subtensor // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: index, %[[ARG1:[0-9a-z]*]]: index func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index) @@ -1189,30 +1132,16 @@ // CHECK: subtensor %{{.*}}[0, 0, 0] [7, 11, 2] [1, 1, 1] : // CHECK-SAME: tensor<8x16x4xf32> to tensor<7x11x2xf32> - // CHECK: tensor_cast %{{.*}} : tensor<7x11x2xf32> to tensor + // CHECK: tensor.cast %{{.*}} : tensor<7x11x2xf32> to tensor %1 = subtensor %t[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] : tensor<8x16x4xf32> to tensor // Test: subtensor with one dynamic operand can also be 
folded. // CHECK: subtensor %{{.*}}[0, 0, 0] [2, %[[ARG0]], 2] [1, 1, 1] : // CHECK-SAME: tensor to tensor<2x?x2xf32> - // CHECK: tensor_cast %{{.*}} : tensor<2x?x2xf32> to tensor + // CHECK: tensor.cast %{{.*}} : tensor<2x?x2xf32> to tensor %2 = subtensor %1[%c0, %c0, %c0] [%c2, %arg0, %c2] [%c1, %c1, %c1] : tensor to tensor return %2 : tensor } - -// ----- - -// CHECK-LABEL: func @extract_from_tensor_cast -// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32> -func @extract_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 { - // CHECK-NEXT: %[[C0:.*]] = constant 0 : index - %c0 = constant 0 : index - // CHECK-NOT: tensor_cast - %casted = tensor_cast %tensor : tensor<*xf32> to tensor - // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]] - %result = tensor.extract %casted[%c0] : tensor - return %result : f32 -} diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -68,10 +68,10 @@ /// types. // CHECK-LABEL: @different_results func @different_results(%arg0: tensor<*xf32>) -> (tensor, tensor<4x?xf32>) { - // CHECK: %0 = tensor_cast %arg0 : tensor<*xf32> to tensor - // CHECK-NEXT: %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32> - %0 = tensor_cast %arg0 : tensor<*xf32> to tensor - %1 = tensor_cast %arg0 : tensor<*xf32> to tensor<4x?xf32> + // CHECK: %0 = tensor.cast %arg0 : tensor<*xf32> to tensor + // CHECK-NEXT: %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32> + %0 = tensor.cast %arg0 : tensor<*xf32> to tensor + %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32> // CHECK-NEXT: return %0, %1 : tensor, tensor<4x?xf32> return %0, %1 : tensor, tensor<4x?xf32> diff --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim --- a/mlir/utils/vim/syntax/mlir.vim +++ b/mlir/utils/vim/syntax/mlir.vim @@ -40,7 +40,7 @@ syn keyword mlirOps constant dealloc divf dma_start dma_wait dim exp syn keyword mlirOps getTensor index_cast load log memref_cast syn keyword mlirOps memref_shape_cast mulf muli negf powf prefetch rsqrt sitofp -syn keyword mlirOps splat store select sqrt subf subi subview tanh tensor_cast +syn keyword mlirOps splat store select sqrt subf subi subview tanh syn keyword mlirOps view " Affine ops.
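As a companion to the `mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp` hunk above, here is a hedged, self-contained sketch of the conversion pattern that now bufferizes `tensor.cast` (the class and header names here are assumptions for illustration): once the operand is a memref, the cast rewrites 1:1 to `memref_cast` with the converted result type, which is exactly what the new `bufferize.mlir` CHECK lines expect.

```cpp
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Sketch of the tensor.cast bufferization pattern.
class BufferizeCastOpSketch : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Map the tensor result type (e.g. tensor<2xindex>) to its memref
    // counterpart via the bufferization type converter, then rewrite the
    // cast over the already-converted memref operand.
    Type resultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
    return success();
  }
};
} // namespace
```

Registration follows the hunk above: the pattern is added in `populateTensorBufferizePatterns` and `tensor::CastOp` is marked illegal, so `-tensor-bufferize` must rewrite every occurrence.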