diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -73,12 +73,6 @@ DeclareOpInterfaceMethods]>, Arguments<(ins From:$in)>, Results<(outs To:$out)> { - let builders = [ - OpBuilder<(ins "Value":$source, "Type":$destType), [{ - impl::buildCastOp($_builder, $_state, source, destType); - }]> - ]; - let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)"; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -419,8 +419,7 @@ let arguments = (ins type:$arg); let results = (outs resultType:$res); let builders = [LLVM_OneResultOpBuilder]; - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; } def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -374,11 +374,6 @@ let arguments = (ins AnyRankedOrUnrankedMemRef:$source); let results = (outs AnyRankedOrUnrankedMemRef:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; - let builders = [ - OpBuilder<(ins "Value":$source, "Type":$destType), [{ - impl::buildCastOp($_builder, $_state, source, destType); - }]> - ]; let extraClassDeclaration = [{ /// Fold the given CastOp into consumer op. 
@@ -388,7 +383,6 @@ }]; let hasFolder = 1; - let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -34,6 +34,7 @@ let results = (outs SPV_ScalarOrVectorOrCoopMatrixOf:$result ); + let assemblyFormat = "operands attr-dict `:` type($result)"; } class SPV_ArithmeticUnaryOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; // No additional verification needed in addition to the ODS-generated ones. let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td @@ -21,7 +21,9 @@ // All the operands type used in bit instructions are SPV_Integer. 
SPV_BinaryOp; + [NoSideEffect, SameOperandsAndResultType])> { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} class SPV_BitFieldExtractOp traits = []> : SPV_Op:$result ); - - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; } // ----- @@ -85,9 +85,9 @@ SPV_ScalarOrVectorOrPtr:$result ); - let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; - let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; - + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; let hasCanonicalizer = 1; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -72,10 +72,6 @@ SPV_ScalarOrVectorOf:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -83,7 +79,10 @@ // return type matches. class SPV_GLSLBinaryArithmeticOp traits = []> : - SPV_GLSLBinaryOp; + SPV_GLSLBinaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} // Base class for GLSL ternary ops. 
class SPV_GLSLTernaryArithmeticOp:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; + let parser = [{ return parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ return printOneResultOp(getOperation(), p); }]; let hasVerifier = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -71,10 +71,6 @@ SPV_ScalarOrVectorOf:$result ); - let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; - - let printer = [{ return impl::printOneResultOp(getOperation(), p); }]; - let hasVerifier = 0; } @@ -82,7 +78,10 @@ // return type matches. class SPV_OCLBinaryArithmeticOp traits = []> : - SPV_OCLBinaryOp; + SPV_OCLBinaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} // ----- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -14,6 +14,7 @@ #define SHAPE_OPS include "mlir/Dialect/Shape/IR/ShapeBase.td" +include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -331,7 +332,9 @@ }]; } -def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { +def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "Creates a dimension tensor from a shape"; let description = [{ Converts a shape to a 1D integral tensor of extents. 
The number of elements @@ -624,7 +627,9 @@ }]; } -def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { +def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [ + DeclareOpInterfaceMethods, NoSideEffect + ]> { let summary = "Casts between index types of the shape and standard dialect"; let description = [{ Converts a `shape.size` to a standard index. This operation and its diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1897,26 +1897,9 @@ }; //===----------------------------------------------------------------------===// -// Common Operation Folders/Parsers/Printers +// CastOpInterface utilities //===----------------------------------------------------------------------===// -// These functions are out-of-line implementations of the methods in UnaryOp and -// BinaryOp, which avoids them being template instantiated/duplicated. -namespace impl { -ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser, - OperationState &result); - -void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, - Value rhs); -ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result); - -// Prints the given binary `op` in custom assembly form if both the two operands -// and the result have the same time. Otherwise, prints the generic assembly -// form. -void printOneResultOp(Operation *op, OpAsmPrinter &p); -} // namespace impl - // These functions are out-of-line implementations of the methods in // CastOpInterface, which avoids them being template instantiated/duplicated. namespace impl { @@ -1927,20 +1910,6 @@ /// Attempt to verify the given cast operation. 
LogicalResult verifyCastInterfaceOp( Operation *op, function_ref areCastCompatible); - -// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the -// need for them, but some older ODS code in `std` still depends on them). -void buildCastOp(OpBuilder &builder, OperationState &result, Value source, - Type destType); -ParseResult parseCastOp(OpAsmParser &parser, OperationState &result); -void printCastOp(Operation *op, OpAsmPrinter &p); -// TODO: These methods are deprecated in favor of CastOpInterface. Remove them -// when all uses have been updated. Also, consider adding functionality to -// CastOpInterface to be able to perform the ChainedTensorCast canonicalization -// generically. -Value foldCastOp(Operation *op); -LogicalResult verifyCastOp(Operation *op, - function_ref areCastCompatible); } // namespace impl } // namespace mlir diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1003,11 +1003,11 @@ switch (conversion) { case PrintConversion::ZeroExt64: value = rewriter.create( - loc, value, IntegerType::get(rewriter.getContext(), 64)); + loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::SignExt64: value = rewriter.create( - loc, value, IntegerType::get(rewriter.getContext(), 64)); + loc, IntegerType::get(rewriter.getContext(), 64), value); break; case PrintConversion::None: break; diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -94,8 +94,8 @@ getMemRefType(castOp.getType().cast(), state.getOptions(), layout, sourceType.getMemorySpace()); - 
replaceOpWithNewBufferizedOp(rewriter, op, source, - resultType); + replaceOpWithNewBufferizedOp(rewriter, op, resultType, + source); return success(); } }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -835,15 +835,15 @@ scalingFactor); } Value numWorkersIndex = - b.create(numWorkerThreadsVal, b.getI32Type()); + b.create(b.getI32Type(), numWorkerThreadsVal); Value numWorkersFloat = - b.create(numWorkersIndex, b.getF32Type()); + b.create(b.getF32Type(), numWorkersIndex); Value scaledNumWorkers = b.create(scalingFactor, numWorkersFloat); Value scaledNumInt = - b.create(scaledNumWorkers, b.getI32Type()); + b.create(b.getI32Type(), scaledNumWorkers); Value scaledWorkers = - b.create(scaledNumInt, b.getIndexType()); + b.create(b.getIndexType(), scaledNumInt); Value maxComputeBlocks = b.create( b.create(1), scaledWorkers); diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -887,7 +887,7 @@ auto i32Vec = broadcast(builder.getI32Type(), shape); // exp2(k) - Value k = builder.create(kF32, i32Vec); + Value k = builder.create(i32Vec, kF32); Value exp2KValue = exp2I32(builder, k); // exp(x) = exp(y) * exp2(k) @@ -1042,7 +1042,7 @@ auto i32Vec = broadcast(builder.getI32Type(), shape); auto fPToSingedInteger = [&](Value a) -> Value { - return builder.create(a, i32Vec); + return builder.create(i32Vec, a); }; auto modulo4 = [&](Value a) -> Value { diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -64,10 +64,6 @@ return 
NoneType::get(type.getContext()); } -LogicalResult memref::CastOp::verify() { - return impl::verifyCastOp(*this, areCastCompatible); -} - //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ -164,7 +160,7 @@ alloc.alignmentAttr()); // Insert a cast so we have the same type as the old alloc. auto resultCast = - rewriter.create(alloc.getLoc(), newAlloc, alloc.getType()); + rewriter.create(alloc.getLoc(), alloc.getType(), newAlloc); rewriter.replaceOp(alloc, {resultCast}); return success(); @@ -2155,8 +2151,8 @@ rewriter.replaceOp(subViewOp, subViewOp.source()); return success(); } - rewriter.replaceOpWithNewOp(subViewOp, subViewOp.source(), - subViewOp.getType()); + rewriter.replaceOpWithNewOp(subViewOp, subViewOp.getType(), + subViewOp.source()); return success(); } }; @@ -2176,7 +2172,7 @@ /// A canonicalizer wrapper to replace SubViewOps. struct SubViewCanonicalizer { void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { - rewriter.replaceOpWithNewOp(op, newOp, op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), newOp); } }; @@ -2421,7 +2417,7 @@ viewOp.getOperand(0), viewOp.byte_shift(), newOperands); // Insert a cast so we have the same type as the old memref type. 
- rewriter.replaceOpWithNewOp(viewOp, newViewOp, viewOp.getType()); + rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), newViewOp); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -101,8 +101,8 @@ Value index = rewriter.create(loc, i); size = rewriter.create(loc, op.shape(), index); if (!size.getType().isa()) - size = rewriter.create(loc, size, - rewriter.getIndexType()); + size = rewriter.create( + loc, rewriter.getIndexType(), size); sizes[i] = size; } else { sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i)); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -64,6 +64,54 @@ // Common utility functions //===----------------------------------------------------------------------===// +static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, + OperationState &result) { + SmallVector ops; + Type type; + // If the operand list is in-between parentheses, then we have a generic form. + // (see the fallback in `printOneResultOp`). 
+ SMLoc loc = parser.getCurrentLocation(); + if (!parser.parseOptionalLParen()) { + if (parser.parseOperandList(ops) || parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(type)) + return failure(); + auto fnType = type.dyn_cast(); + if (!fnType) { + parser.emitError(loc, "expected function type"); + return failure(); + } + if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) + return failure(); + result.addTypes(fnType.getResults()); + return success(); + } + return failure(parser.parseOperandList(ops) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(type) || + parser.resolveOperands(ops, type, result.operands) || + parser.addTypeToList(type, result.types)); +} + +static void printOneResultOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumResults() == 1 && "op should have one result"); + + // If not all the operand and result types are the same, just use the + // generic assembly form to avoid omitting information in printing. + auto resultType = op->getResult(0).getType(); + if (llvm::any_of(op->getOperandTypes(), + [&](Type type) { return type != resultType; })) { + p.printGenericOp(op, /*printOpName=*/false); + return; + } + + p << ' '; + p.printOperands(op->getOperands()); + p.printOptionalAttrDict(op->getAttrs()); + // Now we can output only one type for all operands and the result. + p << " : " << resultType; +} + /// Returns true if the given op is a function-like op or nested in a /// function-like op without a module-like op in the middle. static bool isNestedInFunctionOpInterface(Operation *op) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1597,7 +1597,7 @@ // `IntegerAttr`s which makes constant folding simple. 
if (Attribute arg = operands[0]) return arg; - return impl::foldCastOp(*this); + return OpFoldResult(); } void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -1605,6 +1605,12 @@ patterns.add(context); } +bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + return inputs[0].isa() && outputs[0].isa(); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -1655,7 +1661,7 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { if (!operands[0]) - return impl::foldCastOp(*this); + return OpFoldResult(); Builder builder(getContext()); auto shape = llvm::to_vector<6>( operands[0].cast().getValues()); @@ -1664,6 +1670,22 @@ return DenseIntElementsAttr::get(type, shape); } +bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { + if (inputs.size() != 1 || outputs.size() != 1) + return false; + if (auto inputTensor = inputs[0].dyn_cast()) { + // Static and dynamic 1-D index tensors are both valid extent tensors. + if (!inputTensor.getElementType().isa() || + inputTensor.getRank() != 1) + return false; + } else if (!inputs[0].isa()) { + return false; + } + + TensorType outputTensor = outputs[0].dyn_cast(); + return outputTensor && outputTensor.getElementType().isa(); +} + //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -309,7 +309,7 @@ Value val = rewriter.create(loc, indices, ValueRange{ivs[0], idx}); val = 
rewriter.create(loc, val, rewriter.getIndexType()); + rewriter.create(loc, rewriter.getIndexType(), val); rewriter.create(loc, val, ind, idx); } return rewriter.create(loc, values, ivs[0]); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -831,11 +831,11 @@ if (!etp.isa()) { if (etp.getIntOrFloatBitWidth() < 32) vload = rewriter.create( - loc, vload, vectorType(codegen, rewriter.getI32Type())); + loc, vectorType(codegen, rewriter.getI32Type()), vload); else if (etp.getIntOrFloatBitWidth() < 64 && !codegen.options.enableSIMDIndex32) vload = rewriter.create( - loc, vload, vectorType(codegen, rewriter.getI64Type())); + loc, vectorType(codegen, rewriter.getI64Type()), vload); } return vload; } @@ -846,9 +846,9 @@ Value load = rewriter.create(loc, ptr, s); if (!load.getType().isa()) { if (load.getType().getIntOrFloatBitWidth() < 64) - load = rewriter.create(loc, load, rewriter.getI64Type()); + load = rewriter.create(loc, rewriter.getI64Type(), load); load = - rewriter.create(loc, load, rewriter.getIndexType()); + rewriter.create(loc, rewriter.getIndexType(), load); } return load; } @@ -868,7 +868,7 @@ Value mul = rewriter.create(loc, size, p); if (auto vtp = i.getType().dyn_cast()) { Value inv = - rewriter.create(loc, mul, vtp.getElementType()); + rewriter.create(loc, vtp.getElementType(), mul); mul = genVectorInvariantValue(codegen, rewriter, inv); } return rewriter.create(loc, mul, i); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -671,25 +671,25 @@ rewriter.getZeroAttr(v0.getType())), v0); case kTruncF: - return rewriter.create(loc, v0, inferType(e, v0)); + return 
rewriter.create(loc, inferType(e, v0), v0); case kExtF: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kCastFS: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kCastFU: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kCastSF: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kCastUF: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kCastS: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kCastU: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kTruncI: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); case kBitCast: - return rewriter.create(loc, v0, inferType(e, v0)); + return rewriter.create(loc, inferType(e, v0), v0); // Binary ops. 
case kMulF: return rewriter.create(loc, v0, v1); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -255,7 +255,7 @@ [&](OpBuilder &b, Location loc) { Value res = memref; if (compatibleMemRefType != xferOp.getShapedType()) - res = b.create(loc, memref, compatibleMemRefType); + res = b.create(loc, compatibleMemRefType, memref); scf::ValueVector viewAndIndices{res}; viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), xferOp.indices().end()); @@ -271,7 +271,7 @@ alloc); b.create(loc, copyArgs.first, copyArgs.second); Value casted = - b.create(loc, alloc, compatibleMemRefType); + b.create(loc, compatibleMemRefType, alloc); scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); @@ -309,7 +309,7 @@ [&](OpBuilder &b, Location loc) { Value res = memref; if (compatibleMemRefType != xferOp.getShapedType()) - res = b.create(loc, memref, compatibleMemRefType); + res = b.create(loc, compatibleMemRefType, memref); scf::ValueVector viewAndIndices{res}; viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), xferOp.indices().end()); @@ -324,7 +324,7 @@ loc, MemRefType::get({}, vector.getType()), alloc)); Value casted = - b.create(loc, alloc, compatibleMemRefType); + b.create(loc, compatibleMemRefType, alloc); scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); @@ -360,7 +360,7 @@ [&](OpBuilder &b, Location loc) { Value res = memref; if (compatibleMemRefType != xferOp.getShapedType()) - res = b.create(loc, memref, compatibleMemRefType); + res = b.create(loc, compatibleMemRefType, memref); scf::ValueVector viewAndIndices{res}; 
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), @@ -369,7 +369,7 @@ }, [&](OpBuilder &b, Location loc) { Value casted = - b.create(loc, alloc, compatibleMemRefType); + b.create(loc, compatibleMemRefType, alloc); scf::ValueVector viewAndIndices{casted}; viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(), zero); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1125,69 +1125,7 @@ } //===----------------------------------------------------------------------===// -// BinaryOp implementation -//===----------------------------------------------------------------------===// - -// These functions are out-of-line implementations of the methods in BinaryOp, -// which avoids them being template instantiated/duplicated. - -void impl::buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs, - Value rhs) { - assert(lhs.getType() == rhs.getType()); - result.addOperands({lhs, rhs}); - result.types.push_back(lhs.getType()); -} - -ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - // If the operand list is in-between parentheses, then we have a generic form. - // (see the fallback in `printOneResultOp`). 
- SMLoc loc = parser.getCurrentLocation(); - if (!parser.parseOptionalLParen()) { - if (parser.parseOperandList(ops) || parser.parseRParen() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColon() || parser.parseType(type)) - return failure(); - auto fnType = type.dyn_cast(); - if (!fnType) { - parser.emitError(loc, "expected function type"); - return failure(); - } - if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) - return failure(); - result.addTypes(fnType.getResults()); - return success(); - } - return failure(parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperands(ops, type, result.operands) || - parser.addTypeToList(type, result.types)); -} - -void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) { - assert(op->getNumResults() == 1 && "op should have one result"); - - // If not all the operand and result types are the same, just use the - // generic assembly form to avoid omitting information in printing. - auto resultType = op->getResult(0).getType(); - if (llvm::any_of(op->getOperandTypes(), - [&](Type type) { return type != resultType; })) { - p.printGenericOp(op, /*printOpName=*/false); - return; - } - - p << ' '; - p.printOperands(op->getOperands()); - p.printOptionalAttrDict(op->getAttrs()); - // Now we can output only one type for all operands and the result. - p << " : " << resultType; -} - -//===----------------------------------------------------------------------===// -// CastOp implementation +// CastOpInterface //===----------------------------------------------------------------------===// /// Attempt to fold the given cast operation. 
@@ -1232,50 +1170,6 @@ return success(); } -void impl::buildCastOp(OpBuilder &builder, OperationState &result, Value source, - Type destType) { - result.addOperands(source); - result.addTypes(destType); -} - -ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType srcInfo; - Type srcType, dstType; - return failure(parser.parseOperand(srcInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || - parser.resolveOperand(srcInfo, srcType, result.operands) || - parser.parseKeywordType("to", dstType) || - parser.addTypeToList(dstType, result.types)); -} - -void impl::printCastOp(Operation *op, OpAsmPrinter &p) { - p << ' ' << op->getOperand(0); - p.printOptionalAttrDict(op->getAttrs()); - p << " : " << op->getOperand(0).getType() << " to " - << op->getResult(0).getType(); -} - -Value impl::foldCastOp(Operation *op) { - // Identity cast - if (op->getOperand(0).getType() == op->getResult(0).getType()) - return op->getOperand(0); - return nullptr; -} - -LogicalResult -impl::verifyCastOp(Operation *op, - function_ref areCastCompatible) { - auto opType = op->getOperand(0).getType(); - auto resType = op->getResult(0).getType(); - if (!areCastCompatible(opType, resType)) - return op->emitError("operand type ") - << opType << " and result type " << resType - << " are cast incompatible"; - - return success(); -} - //===----------------------------------------------------------------------===// // Misc. utils //===----------------------------------------------------------------------===//