diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -48,6 +48,7 @@ let cppNamespace = "::mlir::spirv"; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -116,7 +116,7 @@ // spirv.BitcastOp //===----------------------------------------------------------------------===// -OpFoldResult spirv::BitcastOp::fold(ArrayRef /*operands*/) { +OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) { Value curInput = getOperand(); if (getType() == curInput.getType()) return curInput; @@ -139,7 +139,7 @@ // spirv.CompositeExtractOp //===----------------------------------------------------------------------===// -OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef operands) { +OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) { if (auto insertOp = getComposite().getDefiningOp()) { if (getIndices() == insertOp.getIndices()) @@ -160,15 +160,14 @@ llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) { return static_cast(attr.cast().getInt()); })); - return extractCompositeElement(operands[0], indexVector); + return extractCompositeElement(adaptor.getComposite(), indexVector); } //===----------------------------------------------------------------------===// // spirv.Constant //===----------------------------------------------------------------------===// -OpFoldResult spirv::ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "spirv.Constant has no operands"); +OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); } @@ -176,8 +175,7 @@ // spirv.IAdd //===----------------------------------------------------------------------===// -OpFoldResult spirv::IAddOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "spirv.IAdd expects two operands"); +OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) { // x + 0 = x if (matchPattern(getOperand2(), m_Zero())) return getOperand1(); @@ -188,15 +186,15 @@ // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) + b; }); } //===----------------------------------------------------------------------===// // spirv.IMul //===----------------------------------------------------------------------===// -OpFoldResult spirv::IMulOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "spirv.IMul expects two operands"); +OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) { // x * 0 == 0 if (matchPattern(getOperand2(), m_Zero())) return getOperand2(); @@ -210,14 +208,15 @@ // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. return constFoldBinaryOp( - operands, [](const APInt &a, const APInt &b) { return a * b; }); + adaptor.getOperands(), + [](const APInt &a, const APInt &b) { return a * b; }); } //===----------------------------------------------------------------------===// // spirv.ISub //===----------------------------------------------------------------------===// -OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { +OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) { // x - x = 0 if (getOperand1() == getOperand2()) return Builder(getContext()).getIntegerAttr(getType(), 0); @@ -228,24 +227,23 @@ // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) - b; }); } //===----------------------------------------------------------------------===// // spirv.LogicalAnd //===----------------------------------------------------------------------===// -OpFoldResult spirv::LogicalAndOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "spirv.LogicalAnd should take two operands"); - - if (Optional rhs = getScalarOrSplatBoolAttr(operands.back())) { +OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) { + if (Optional rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { // x && true = x if (*rhs) return getOperand1(); // x && false = false if (!*rhs) - return operands.back(); + return adaptor.getOperand2(); } return Attribute(); @@ -255,11 +253,8 @@ // spirv.LogicalNotEqualOp //===----------------------------------------------------------------------===// -OpFoldResult spirv::LogicalNotEqualOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && - "spirv.LogicalNotEqual should take two operands"); - - if (Optional rhs = getScalarOrSplatBoolAttr(operands.back())) { +OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) { + if (Optional rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { // x && false = x if (!rhs.value()) return getOperand1(); @@ -284,13 +279,11 @@ // spirv.LogicalOr //===----------------------------------------------------------------------===// -OpFoldResult spirv::LogicalOrOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "spirv.LogicalOr should take two operands"); - - if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) { +OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) { + if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) { if (*rhs) // x || true = true - return operands.back(); + return adaptor.getOperand2(); // x || false = x if (!*rhs)