diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -95,6 +95,36 @@
   }];
 }
 
+
+//===----------------------------------------------------------------------===//
+// Bitcast
+//===----------------------------------------------------------------------===//
+
+def BitcastOp : Complex_Op<"bitcast", [Pure]> {
+
+  let summary = "computes bitcast between complex and equal-width arith types";
+  let description = [{
+    The `complex.bitcast` operation reinterprets the bits of its operand as a
+    value of the result type. Exactly one of the two types must be a complex
+    type; the other must be an integer or floating-point type whose bitwidth
+    is twice the bitwidth of the complex element type. A cast to the
+    operand's own type is also accepted and folds away.
+
+    Example:
+
+    ```mlir
+      %a = complex.bitcast %b : complex<f32> to i64
+    ```
+  }];
+  let assemblyFormat = "$operand attr-dict `:` type($operand) `to` type($result)";
+  let arguments = (ins AnyType:$operand);
+  let results = (outs AnyType:$result);
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -72,6 +72,128 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
+  // A bitcast to the operand's own type is a no-op.
+  if (getOperand().getType() == getType())
+    return getOperand();
+
+  return {};
+}
+
+LogicalResult BitcastOp::verify() {
+  Type operandType = getOperand().getType();
+  Type resultType = getType();
+
+  // We allow this to be legal as it can be folded away.
+  if (operandType == resultType)
+    return success();
+
+  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType))
+    return emitOpError("operand must be int/float/complex");
+
+  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType))
+    return emitOpError("result must be int/float/complex");
+
+  // Exactly one side must be complex: casts between two different complex
+  // types or between two different int/float types are rejected.
+  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType))
+    return emitOpError("requires input or output is a complex type");
+
+  // Identify which side is complex and which is the int/float side.
+  auto complexType = dyn_cast<ComplexType>(operandType);
+  Type scalarType = resultType;
+  if (!complexType) {
+    complexType = cast<ComplexType>(resultType);
+    scalarType = operandType;
+  }
+
+  // A complex value stores two elements, so it is twice as wide as its
+  // element type.
+  unsigned complexBitwidth =
+      complexType.getElementType().getIntOrFloatBitWidth() * 2;
+  if (complexBitwidth != scalarType.getIntOrFloatBitWidth())
+    return emitOpError("casting bitwidths do not match");
+
+  return success();
+}
+
+namespace {
+/// Merges a complex.bitcast whose operand is produced by another
+/// complex.bitcast or by an arith.bitcast into a single cast from the
+/// original source value.
+struct BitcastMerge final : OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *defining = op.getOperand().getDefiningOp();
+    if (!isa_and_nonnull<BitcastOp, arith::BitcastOp>(defining))
+      return failure();
+
+    // Both bitcast flavours have exactly one operand: the value being cast.
+    Value source = defining->getOperand(0);
+    Type sourceType = source.getType();
+    Type resultType = op.getType();
+
+    // Casting back to the original type cancels out entirely.
+    if (sourceType == resultType) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    // complex.bitcast is only valid when exactly one side is complex.
+    bool sourceIsComplex = isa<ComplexType>(sourceType);
+    if (sourceIsComplex != isa<ComplexType>(resultType)) {
+      rewriter.replaceOpWithNewOp<BitcastOp>(op, resultType, source);
+      return success();
+    }
+
+    // Neither side is complex: the merged cast is an arith.bitcast, which
+    // only accepts signless integers and floats.
+    if (!sourceIsComplex && sourceType.isSignlessIntOrFloat() &&
+        resultType.isSignlessIntOrFloat()) {
+      rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, resultType, source);
+      return success();
+    }
+
+    // Remaining cases (e.g. complex<f32> -> i64 -> complex<i32>) have no
+    // single-op equivalent; keep the chain.
+    return failure();
+  }
+};
+
+/// Merges an arith.bitcast whose operand is a complex.bitcast from a complex
+/// value into a single complex.bitcast.
+struct ArithBitcastMerge final : OpRewritePattern<arith::BitcastOp> {
+  using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto defining = op.getOperand().getDefiningOp<BitcastOp>();
+    if (!defining)
+      return failure();
+
+    // The merged op is only valid when the original source is complex; a
+    // same-type scalar complex.bitcast is removed by the folder instead.
+    Value source = defining.getOperand();
+    if (!isa<ComplexType>(source.getType()))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), source);
+    return success();
+  }
+};
+} // namespace
+
+void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<BitcastMerge, ArithBitcastMerge>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -177,3 +177,53 @@
   // CHECK-NEXT: return %[[NEG]]
   return %im : f32
 }
+
+// CHECK-LABEL: func @fold
+// CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @fold(%arg0 : complex<f32>) -> complex<f32> {
+  %0 = complex.bitcast %arg0 : complex<f32> to i64
+  %1 = complex.bitcast %0 : i64 to complex<f32>
+  // CHECK: return %[[ARG0]] : complex<f32>
+  func.return %1 : complex<f32>
+}
+
+// CHECK-LABEL: func @double_bitcast
+// CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @double_bitcast(%arg0 : f64) -> complex<f32> {
+  // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]]
+  %0 = arith.bitcast %arg0 : f64 to i64
+  %1 = complex.bitcast %0 : i64 to complex<f32>
+  // CHECK: return %[[R0]] : complex<f32>
+  func.return %1 : complex<f32>
+}
+
+// CHECK-LABEL: func @double_reverse_bitcast
+// CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @double_reverse_bitcast(%arg0 : complex<f32>) -> f64 {
+  // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]]
+  %0 = complex.bitcast %arg0 : complex<f32> to i64
+  %1 = arith.bitcast %0 : i64 to f64
+  // CHECK: return %[[R0]] : f64
+  func.return %1 : f64
+}
+
+// CHECK-LABEL: func @double_bitcast_to_arith
+// CHECK-SAME: %[[ARG0:.*]]: i64
+func.func @double_bitcast_to_arith(%arg0 : i64) -> f64 {
+  // CHECK: %[[R0:.+]] = arith.bitcast %[[ARG0]] : i64 to f64
+  %0 = complex.bitcast %arg0 : i64 to complex<f32>
+  %1 = complex.bitcast %0 : complex<f32> to f64
+  // CHECK: return %[[R0]] : f64
+  func.return %1 : f64
+}
+
+// CHECK-LABEL: func @double_bitcast_complex_mismatch
+// CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @double_bitcast_complex_mismatch(%arg0 : complex<f32>) -> complex<i32> {
+  // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]] : complex<f32> to i64
+  // CHECK: %[[R1:.+]] = complex.bitcast %[[R0]] : i64 to complex<i32>
+  %0 = complex.bitcast %arg0 : complex<f32> to i64
+  %1 = complex.bitcast %0 : i64 to complex<i32>
+  // CHECK: return %[[R1]] : complex<i32>
+  func.return %1 : complex<i32>
+}
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -21,3 +21,19 @@
   %0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
   return
 }
+
+// -----
+
+func.func @complex_bitcast_i64(%arg0 : i64) {
+  // expected-error @+1 {{op requires input or output is a complex type}}
+  %0 = complex.bitcast %arg0: i64 to f64
+  return
+}
+
+// -----
+
+func.func @complex_bitcast_bitwidth_mismatch(%arg0 : complex<f32>) {
+  // expected-error @+1 {{op casting bitwidths do not match}}
+  %0 = complex.bitcast %arg0 : complex<f32> to i32
+  return
+}
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -83,5 +83,8 @@
   // CHECK: complex.tan %[[C]] : complex<f32>
   %tan = complex.tan %complex : complex<f32>
 
+  // CHECK: complex.bitcast %[[C]] : complex<f32> to i64
+  %i64 = complex.bitcast %complex : complex<f32> to i64
+
   return
 }