diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -53,6 +53,36 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_BitcastOp : Tensor_Op<"bitcast", [
+    DeclareOpInterfaceMethods<CastOpInterface>,
+    Pure
+  ]> {
+  let summary = "tensor bitcast operation";
+  let description = [{
+    Bitcast a tensor from one type to another type of equivalent element width.
+    Both element types must be integer or floating-point types.
+    If both are ranked, then the rank should be the same and static dimensions
+    should match.
+
+    Example:
+
+    ```mlir
+    // Bitcast from unsigned to signed or signless integer.
+    %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$source);
+  let results = (outs AnyTensor:$dest);
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -162,6 +162,67 @@
   return droppedDims;
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if a tensor.bitcast from `inputs` to `outputs` is legal: there
+/// must be exactly one tensor type on each side, both element types must be
+/// integer or float types of the same bit width, and the shapes must be
+/// compatible.
+bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  Type a = inputs.front(), b = outputs.front();
+  auto aT = dyn_cast<TensorType>(a);
+  auto bT = dyn_cast<TensorType>(b);
+  if (!aT || !bT)
+    return false;
+
+  // getElementTypeBitWidth() asserts on element types that are neither
+  // integer nor float (e.g. `index`), so reject those up front and let the
+  // verifier report an error instead of crashing.
+  if (!aT.getElementType().isIntOrFloat() ||
+      !bT.getElementType().isIntOrFloat())
+    return false;
+
+  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
+    return false;
+
+  return succeeded(verifyCompatibleShape(aT, bT));
+}
+
+namespace {
+
+/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
+/// operation.
+struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
+                                PatternRewriter &rewriter) const final {
+    auto tensorBitcastOperand =
+        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
+    if (!tensorBitcastOperand)
+      return failure();
+
+    // Bitcast straight from the inner op's source. If that yields an identity
+    // bitcast (e.g. i32 -> ui32 -> i32), the CastOpInterface fold hook is
+    // expected to remove it (see @tensor_bitcast_chain_nop).
+    auto resultType = cast<TensorType>(tensorBitcast.getType());
+    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
+                                           tensorBitcastOperand.getOperand());
+    return success();
+  }
+};
+
+} // namespace
+
+void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ChainedTensorBitcast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1,5 +1,28 @@
 // RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
 
+// CHECK-LABEL: @tensor_bitcast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
+func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
+  // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
+  %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
+  %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_bitcast_chain_nop
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
+  %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
+  // CHECK-NEXT: return %[[IN]]
+  return %1 : tensor<4xi32>
+}
+
+// -----
+
 // Checks that NOP casts are removed.
 // CHECK-LABEL: cast_values
 func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {