diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRComplexDialect
   MLIRDialect
   MLIRIR
   )
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -84,45 +86,138 @@
   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
 }
 
-Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
-                                 Type toType, bool isUnsignedCast) {
-  if (operand.getType() == toType)
-    return operand;
-  if (auto toIntType = dyn_cast<IntegerType>(toType)) {
-    // If operand is floating point, cast directly to the int type.
-    if (isa<FloatType>(operand.getType())) {
-      if (isUnsignedCast)
-        return b.create<arith::FPToUIOp>(loc, toType, operand);
-      return b.create<arith::FPToSIOp>(loc, toType, operand);
-    }
-    // Cast index operands directly to the int type.
-    if (operand.getType().isIndex())
-      return b.create<arith::IndexCastOp>(loc, toType, operand);
-    if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
-      // Either extend or truncate.
-      if (toIntType.getWidth() > fromIntType.getWidth()) {
-        if (isUnsignedCast)
-          return b.create<arith::ExtUIOp>(loc, toType, operand);
-        return b.create<arith::ExtSIOp>(loc, toType, operand);
-      }
-      if (toIntType.getWidth() < fromIntType.getWidth())
-        return b.create<arith::TruncIOp>(loc, toType, operand);
-    }
-  } else if (auto toFloatType = dyn_cast<FloatType>(toType)) {
-    // If operand is integer, cast directly to the float type.
-    // Note that it is unclear how to cast from BF16<->FP16.
-    if (isa<IntegerType>(operand.getType())) {
-      if (isUnsignedCast)
-        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
-      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
-    }
-    if (auto fromFloatType = dyn_cast<FloatType>(operand.getType())) {
-      if (toFloatType.getWidth() > fromFloatType.getWidth())
-        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
-      if (toFloatType.getWidth() < fromFloatType.getWidth())
-        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
-    }
-  }
+/// Converts `operand` to the integer type `toType`: floats use fptoui/fptosi,
+/// index uses index_cast and integers are extended or truncated. Returns a
+/// null Value when no single arith cast applies.
+static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
+                                     IntegerType toType, bool isUnsigned) {
+  // If operand is floating point, cast directly to the int type.
+  if (isa<FloatType>(operand.getType())) {
+    if (isUnsigned)
+      return b.create<arith::FPToUIOp>(toType, operand);
+    return b.create<arith::FPToSIOp>(toType, operand);
+  }
+  // Cast index operands directly to the int type.
+  if (operand.getType().isIndex())
+    return b.create<arith::IndexCastOp>(toType, operand);
+  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
+    // Either extend or truncate.
+    if (toType.getWidth() > fromIntType.getWidth()) {
+      if (isUnsigned)
+        return b.create<arith::ExtUIOp>(toType, operand);
+      return b.create<arith::ExtSIOp>(toType, operand);
+    }
+    if (toType.getWidth() < fromIntType.getWidth())
+      return b.create<arith::TruncIOp>(toType, operand);
+    // Equal width but a distinct type (identical types never reach here):
+    // no arith cast applies, so fall through and let the caller warn instead
+    // of silently returning a value of the wrong type.
+  }
+
+  return {};
+}
+
+/// Converts `operand` to the floating-point type `toType`: integers use
+/// uitofp/sitofp and floats are extended or truncated. Returns a null Value
+/// when no single arith cast applies.
+static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
+                                    FloatType toType, bool isUnsigned) {
+  // If operand is integer, cast directly to the float type.
+  // Note that it is unclear how to cast from BF16<->FP16.
+  if (isa<IntegerType>(operand.getType())) {
+    if (isUnsigned)
+      return b.create<arith::UIToFPOp>(toType, operand);
+    return b.create<arith::SIToFPOp>(toType, operand);
+  }
+  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
+    if (toType.getWidth() > fromFpTy.getWidth())
+      return b.create<arith::ExtFOp>(toType, operand);
+    if (toType.getWidth() < fromFpTy.getWidth())
+      return b.create<arith::TruncFOp>(toType, operand);
+    // Equal width but a distinct type (e.g. bf16 vs. f16): no arith cast
+    // applies, so fall through and let the caller warn.
+  }
+
+  return {};
+}
+
+/// Converts `operand` to `targetType`, which must have a floating-point
+/// element type. A complex operand has its real and imaginary parts extended
+/// or truncated; a float or integer operand is converted to the element type
+/// and used as the real part, with a zero imaginary part. Returns a null
+/// Value when no conversion applies.
+static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
+                                         ComplexType targetType,
+                                         bool isUnsigned) {
+  // Only float element types can be produced with arith casts; bail out
+  // rather than asserting in cast<> for targets such as complex<i32>.
+  auto toFpTy = dyn_cast<FloatType>(targetType.getElementType());
+  if (!toFpTy)
+    return {};
+  unsigned toBitwidth = toFpTy.getWidth();
+
+  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
+    auto fromFpTy = dyn_cast<FloatType>(fromComplexType.getElementType());
+    // Identical complex types never reach here, so equal widths mean distinct
+    // float types (e.g. bf16 vs. f16), which no single arith cast converts.
+    if (!fromFpTy || fromFpTy.getWidth() == toBitwidth)
+      return {};
+    Value real = b.create<complex::ReOp>(operand);
+    Value imag = b.create<complex::ImOp>(operand);
+    if (toBitwidth < fromFpTy.getWidth()) {
+      real = b.create<arith::TruncFOp>(toFpTy, real);
+      imag = b.create<arith::TruncFOp>(toFpTy, imag);
+    } else {
+      real = b.create<arith::ExtFOp>(toFpTy, real);
+      imag = b.create<arith::ExtFOp>(toFpTy, imag);
+    }
+    return b.create<complex::CreateOp>(targetType, real, imag);
+  }
+
+  // Real operand: convert it to the element type to form the real part.
+  Value real;
+  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
+    if (fromFpTy == toFpTy)
+      real = operand;
+    else if (fromFpTy.getWidth() < toBitwidth)
+      real = b.create<arith::ExtFOp>(toFpTy, operand);
+    else if (fromFpTy.getWidth() > toBitwidth)
+      real = b.create<arith::TruncFOp>(toFpTy, operand);
+    else
+      return {}; // Same width, distinct float type (e.g. bf16 -> f16).
+  } else if (isa<IntegerType>(operand.getType())) {
+    if (isUnsigned)
+      real = b.create<arith::UIToFPOp>(toFpTy, operand);
+    else
+      real = b.create<arith::SIToFPOp>(toFpTy, operand);
+  } else {
+    return {};
+  }
+
+  Value zero = b.create<mlir::arith::ConstantFloatOp>(
+      mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+  return b.create<complex::CreateOp>(targetType, real, zero);
+}
+
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                                 Type toType, bool isUnsignedCast) {
+  if (operand.getType() == toType)
+    return operand;
+  ImplicitLocOpBuilder ib(loc, b);
+  Value result;
+  if (auto intTy = dyn_cast<IntegerType>(toType)) {
+    result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
+  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
+    result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
+  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
+    result =
+        convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
+  }
+
+  // A null result means no conversion applies; warn and return the operand.
+  if (result)
+    return result;
+
   emitWarning(loc) << "could not cast operand of type " << operand.getType()
                    << " to " << toType;
   return operand;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -29,9 +30,10 @@
 }
 
 static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
-  bool isInt = isa<IntegerType>(x.getType());
-  if (isInt)
+  if (isa<IntegerType>(x.getType()))
     return builder.create<arith::AddIOp>(loc, x, y);
+  if (isa<ComplexType>(x.getType()))
+    return builder.create<complex::AddOp>(loc, x, y);
   return builder.create<arith::AddFOp>(loc, x, y);
 }
 
@@ -42,6 +44,8 @@
       convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
   Value yConvert =
       convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
+  if (isa<ComplexType>(accType))
+    return builder.create<complex::MulOp>(loc, xConvert, yConvert);
   if (isa<IntegerType>(accType))
     return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
   return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
@@ -111,7 +115,7 @@
   // Reshape output and filter to the LHS and result of a (B)MNK matmul.
   SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
   auto reshapedFilterType =
-      RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
   Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
       loc, reshapedFilterType, filter, filterReassocIndices);
 
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -314,3 +314,119 @@
   transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
   transform.print %transformed {name = "transformed"}: !transform.any_op
 }
+
+// -----
+
+// Check for compatible complex case.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_complex
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f32>>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
+// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f32>, %[[ARG2:.+]]: complex<f32>)
+// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[ARG1]] : complex<f32>
+// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
+// CHECK: linalg.yield %[[ADD]] : complex<f32>
+// CHECK: } -> tensor<1x196x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f32>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f32>>)
+    outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
+  return %0 : tensor<1x14x14x16xcomplex<f32>>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+  transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
+// Check for compatible complex extended case.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_complex_extended
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f16>>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
+// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f16>, %[[ARG2:.+]]: complex<f32>)
+// CHECK: %[[REAL:.+]] = complex.re %[[ARG1]] : complex<f16>
+// CHECK: %[[IMAG:.+]] = complex.im %[[ARG1]] : complex<f16>
+// CHECK: %[[REEXT:.+]] = arith.extf %[[REAL]] : f16 to f32
+// CHECK: %[[IMEXT:.+]] = arith.extf %[[IMAG]] : f16 to f32
+// CHECK: %[[COMPLEX:.+]] = complex.create %[[REEXT]], %[[IMEXT]] : complex<f32>
+// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
+// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
+// CHECK: linalg.yield %[[ADD]] : complex<f32>
+// CHECK: } -> tensor<1x196x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f16>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f16>>)
+    outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
+  return %0 : tensor<1x14x14x16xcomplex<f32>>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+  transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
+// Check for compatible real (f16) to complex (f32) extended case.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_complex_f16_extended
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xf16>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
+// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: f16, %[[ARG2:.+]]: complex<f32>)
+// CHECK: %[[EXT:.+]] = arith.extf %[[ARG1]] : f16 to f32
+// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COMPLEX:.+]] = complex.create %[[EXT]], %[[ZERO]]
+// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
+// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
+// CHECK: linalg.yield %[[ADD]] : complex<f32>
+// CHECK: } -> tensor<1x196x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
+  %0 = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+    ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xf16>)
+    outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
+  return %0 : tensor<1x14x14x16xcomplex<f32>>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+  transform.print %transformed {name = "transformed"}: !transform.any_op
+}