diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" @@ -125,8 +126,17 @@ return rewriter.notifyMatchFailure(opToPad, "--no padding value specified"); } Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()]; - Value paddingValue = rewriter.create( - opToPad.getLoc(), cast(paddingAttr)); + + Value paddingValue; + if (auto complexTy = dyn_cast( + getElementTypeOrSelf(opOperand->get().getType()))) { + auto complexAttr = cast(paddingAttr); + paddingValue = rewriter.create(opToPad.getLoc(), + complexTy, complexAttr); + } else { + paddingValue = rewriter.create( + opToPad.getLoc(), cast(paddingAttr)); + } // Pad the operand to the bounding box defined by `paddedShape`. auto paddedTensorType = RankedTensorType::get(