diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -265,7 +265,7 @@ Optional buildTensorExpFromLinalg(linalg::GenericOp op); /// Rebuilds SSA format from a tensor expression. - Value buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0, + Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, Value v1); private: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -43,8 +43,8 @@ /// Returns the equivalent of `void*` for opaque arguments to the /// execution engine. -static Type getOpaquePointerType(PatternRewriter &rewriter) { - return LLVM::LLVMPointerType::get(rewriter.getI8Type()); +static Type getOpaquePointerType(OpBuilder &builder) { + return LLVM::LLVMPointerType::get(builder.getI8Type()); } /// Returns a function reference (first hit also inserts into module). Sets @@ -81,9 +81,8 @@ /// Replaces the `op` with a `CallOp` to the function reference returned /// by `getFunc()`. -static func::CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, - Operation *op, StringRef name, - TypeRange resultType, +static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op, + StringRef name, TypeRange resultType, ValueRange operands, EmitCInterface emitCInterface) { auto fn = getFunc(op, name, resultType, operands, emitCInterface); @@ -92,7 +91,7 @@ } /// Generates dimension size call. -static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op, +static Value genDimSizeCall(OpBuilder &builder, Operation *op, SparseTensorEncodingAttr &enc, Value src, int64_t idx) { // Permute the index according to an optional dimension ordering. @@ -100,72 +99,67 @@ idx = p.getPermutedPosition(idx); // Generate the call. StringRef name = "sparseDimSize"; - SmallVector params{src, constantIndex(rewriter, op->getLoc(), idx)}; - Type iTp = rewriter.getIndexType(); - return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off) + SmallVector params{src, constantIndex(builder, op->getLoc(), idx)}; + Type iTp = builder.getIndexType(); + return createFuncCall(builder, op, name, iTp, params, EmitCInterface::Off) .getResult(0); } /// Generates a call into the "swiss army knife" method of the sparse runtime /// support library for materializing sparse tensors into the computation. -static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op, +static Value genNewCall(OpBuilder &builder, Operation *op, ArrayRef params) { StringRef name = "newSparseTensor"; - Type pTp = getOpaquePointerType(rewriter); - return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On) + Type pTp = getOpaquePointerType(builder); + return createFuncCall(builder, op, name, pTp, params, EmitCInterface::On) .getResult(0); } /// Populates given sizes array from type. -static void sizesFromType(ConversionPatternRewriter &rewriter, - SmallVector &sizes, Location loc, - ShapedType stp) { +static void sizesFromType(OpBuilder &builder, SmallVector &sizes, + Location loc, ShapedType stp) { auto shape = stp.getShape(); for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) { uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i]; - sizes.push_back(constantIndex(rewriter, loc, s)); + sizes.push_back(constantIndex(builder, loc, s)); } } /// Populates given sizes array from source. -static void sizesFromSrc(ConversionPatternRewriter &rewriter, - SmallVector &sizes, Location loc, - Value src) { +static void sizesFromSrc(OpBuilder &builder, SmallVector &sizes, + Location loc, Value src) { unsigned rank = src.getType().cast().getRank(); for (unsigned i = 0; i < rank; i++) - sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i)); + sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i)); } /// Populates given sizes array from type (for static sizes) and from /// an already converted into opague pointer source (for dynamic sizes). -static void sizesFromPtr(ConversionPatternRewriter &rewriter, - SmallVector &sizes, Operation *op, - SparseTensorEncodingAttr &enc, ShapedType stp, - Value src) { +static void sizesFromPtr(OpBuilder &builder, SmallVector &sizes, + Operation *op, SparseTensorEncodingAttr &enc, + ShapedType stp, Value src) { Location loc = op->getLoc(); auto shape = stp.getShape(); for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) if (shape[i] == ShapedType::kDynamicSize) - sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i)); + sizes.push_back(genDimSizeCall(builder, op, enc, src, i)); else - sizes.push_back(constantIndex(rewriter, loc, shape[i])); + sizes.push_back(constantIndex(builder, loc, shape[i])); } /// Generates an uninitialized temporary buffer of the given size and /// type, but returns it as type `memref` (rather than as type /// `memref<$sz x $tp>`). -static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc, - Value sz, Type tp) { +static Value genAlloca(OpBuilder &builder, Location loc, Value sz, Type tp) { auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp); - return rewriter.create(loc, memTp, ValueRange{sz}); + return builder.create(loc, memTp, ValueRange{sz}); } /// Generates an uninitialized buffer of the given size and type, /// but returns it as type `memref` (rather than as type /// `memref<$sz x $tp>`). Unlike temporary buffers on the stack, /// this buffer must be explicitly deallocated by client. -static Value genAlloc(ConversionPatternRewriter &rewriter, Location loc, - Value sz, Type tp) { +static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp); return rewriter.create(loc, memTp, ValueRange{sz}); } @@ -173,27 +167,24 @@ /// Generates an uninitialized temporary buffer of the given size and /// type, but returns it as type `memref` (rather than as type /// `memref<$sz x $tp>`). -static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc, - unsigned sz, Type tp) { - return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp); +static Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp) { + return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp); } /// Generates an uninitialized temporary buffer with room for one value /// of the given type, and returns the `memref<$tp>`. -static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc, - Type tp) { - return rewriter.create(loc, MemRefType::get({}, tp)); +static Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp) { + return builder.create(loc, MemRefType::get({}, tp)); } /// Generates a temporary buffer of the given type and given contents. -static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc, - ValueRange values) { +static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) { unsigned sz = values.size(); assert(sz >= 1); - Value buffer = genAlloca(rewriter, loc, sz, values[0].getType()); + Value buffer = genAlloca(builder, loc, sz, values[0].getType()); for (unsigned i = 0; i < sz; i++) { - Value idx = constantIndex(rewriter, loc, i); - rewriter.create(loc, values[i], buffer, idx); + Value idx = constantIndex(builder, loc, i); + builder.create(loc, values[i], buffer, idx); } return buffer; } @@ -201,43 +192,43 @@ /// Populates parameters required to call the "swiss army knife" method of the /// sparse runtime support library for materializing sparse tensors into the /// computation. -static void newParams(ConversionPatternRewriter &rewriter, - SmallVector ¶ms, Operation *op, - ShapedType stp, SparseTensorEncodingAttr &enc, - Action action, ValueRange szs, Value ptr = Value()) { +static void newParams(OpBuilder &builder, SmallVector ¶ms, + Operation *op, ShapedType stp, + SparseTensorEncodingAttr &enc, Action action, + ValueRange szs, Value ptr = Value()) { Location loc = op->getLoc(); ArrayRef dlt = enc.getDimLevelType(); unsigned sz = dlt.size(); // Sparsity annotations. SmallVector attrs; for (unsigned i = 0; i < sz; i++) - attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i])); - params.push_back(genBuffer(rewriter, loc, attrs)); + attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i])); + params.push_back(genBuffer(builder, loc, attrs)); // Dimension sizes array of the enveloping tensor. Useful for either // verification of external data, or for construction of internal data. - params.push_back(genBuffer(rewriter, loc, szs)); + params.push_back(genBuffer(builder, loc, szs)); // Dimension order permutation array. This is the "identity" permutation by // default, or otherwise the "reverse" permutation of a given ordering, so // that indices can be mapped quickly to the right position. SmallVector rev(sz); if (AffineMap p = enc.getDimOrdering()) { for (unsigned i = 0; i < sz; i++) - rev[p.getDimPosition(i)] = constantIndex(rewriter, loc, i); + rev[p.getDimPosition(i)] = constantIndex(builder, loc, i); } else { for (unsigned i = 0; i < sz; i++) - rev[i] = constantIndex(rewriter, loc, i); + rev[i] = constantIndex(builder, loc, i); } - params.push_back(genBuffer(rewriter, loc, rev)); + params.push_back(genBuffer(builder, loc, rev)); // Secondary and primary types encoding. Type elemTp = stp.getElementType(); - params.push_back(constantPointerTypeEncoding(rewriter, loc, enc)); - params.push_back(constantIndexTypeEncoding(rewriter, loc, enc)); - params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp)); + params.push_back(constantPointerTypeEncoding(builder, loc, enc)); + params.push_back(constantIndexTypeEncoding(builder, loc, enc)); + params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp)); // User action. - params.push_back(constantAction(rewriter, loc, action)); + params.push_back(constantAction(builder, loc, action)); // Payload pointer. if (!ptr) - ptr = rewriter.create(loc, getOpaquePointerType(rewriter)); + ptr = builder.create(loc, getOpaquePointerType(builder)); params.push_back(ptr); } @@ -248,17 +239,16 @@ /// addEltX call generated after is inside the if-then branch. /// if (tensor[ivs]!=0) { /// ind = ivs -static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter, - Location loc, Value tensor, Value ind, - ValueRange ivs) { - Value val = rewriter.create(loc, tensor, ivs); - Value cond = genIsNonzero(rewriter, loc, val); - scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ false); - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); +static Value genIndexAndValueForDense(OpBuilder &builder, Location loc, + Value tensor, Value ind, ValueRange ivs) { + Value val = builder.create(loc, tensor, ivs); + Value cond = genIsNonzero(builder, loc, val); + scf::IfOp ifOp = builder.create(loc, cond, /*else*/ false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); unsigned i = 0; for (auto iv : ivs) { - Value idx = constantIndex(rewriter, loc, i++); - rewriter.create(loc, iv, ind, idx); + Value idx = constantIndex(builder, loc, i++); + builder.create(loc, iv, ind, idx); } return val; } @@ -276,40 +266,38 @@ /// val = a[i1,..,ik]; /// if val != 0 /// t->add(val, [i1,..,ik], [p1,..,pk]); -static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op, - Type eltType, Value ptr, Value val, Value ind, - Value perm) { +static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType, + Value ptr, Value val, Value ind, Value perm) { SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)}; SmallVector params{ptr, val, ind, perm}; - Type pTp = getOpaquePointerType(rewriter); - createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On); + Type pTp = getOpaquePointerType(builder); + createFuncCall(builder, op, name, pTp, params, EmitCInterface::On); } /// Generates a call to `iter->getNext()`. If there is a next element, /// then it is copied into the out-parameters `ind` and `elemPtr`, /// and the return value is true. If there isn't a next element, then /// the memory for `iter` is freed and the return value is false. -static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op, - Value iter, Value ind, Value elemPtr) { +static Value genGetNextCall(OpBuilder &builder, Operation *op, Value iter, + Value ind, Value elemPtr) { Type elemTp = elemPtr.getType().cast().getElementType(); SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)}; SmallVector params{iter, ind, elemPtr}; - Type i1 = rewriter.getI1Type(); - return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On) + Type i1 = builder.getI1Type(); + return createFuncCall(builder, op, name, i1, params, EmitCInterface::On) .getResult(0); } /// If the tensor is a sparse constant, generates and returns the pair of /// the constants for the indices and the values. static Optional> -genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc, - Value tensor) { +genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) { if (auto constOp = tensor.getDefiningOp()) { if (auto attr = constOp.getValue().dyn_cast()) { DenseElementsAttr indicesAttr = attr.getIndices(); - Value indices = rewriter.create(loc, indicesAttr); + Value indices = builder.create(loc, indicesAttr); DenseElementsAttr valuesAttr = attr.getValues(); - Value values = rewriter.create(loc, valuesAttr); + Value values = builder.create(loc, valuesAttr); return std::make_pair(indices, values); } } @@ -318,26 +306,24 @@ /// Generates the code to copy the index at indices[ivs] to ind, and return /// the value at value[ivs]. -static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter, - Location loc, Value indices, - Value values, Value ind, ValueRange ivs, - unsigned rank) { +static Value genIndexAndValueForSparse(OpBuilder &builder, Location loc, + Value indices, Value values, Value ind, + ValueRange ivs, unsigned rank) { for (unsigned i = 0; i < rank; i++) { - Value idx = constantIndex(rewriter, loc, i); - Value val = rewriter.create(loc, indices, - ValueRange{ivs[0], idx}); - val = - rewriter.create(loc, rewriter.getIndexType(), val); - rewriter.create(loc, val, ind, idx); + Value idx = constantIndex(builder, loc, i); + Value val = builder.create(loc, indices, + ValueRange{ivs[0], idx}); + val = builder.create(loc, builder.getIndexType(), val); + builder.create(loc, val, ind, idx); } - return rewriter.create(loc, values, ivs[0]); + return builder.create(loc, values, ivs[0]); } /// Generates code to allocate a tensor of the given type, and zero /// initialize it. If the tensor type has any dynamic sizes, then the /// `sizes` parameter should be as filled by sizesFromPtr(); that way /// we can reuse the genDimSizeCall() results generated by sizesFromPtr(). -static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc, +static Value allocDenseTensor(OpBuilder &builder, Location loc, RankedTensorType tensorTp, ValueRange sizes) { Type elemTp = tensorTp.getElementType(); auto shape = tensorTp.getShape(); @@ -347,27 +333,26 @@ if (shape[i] == ShapedType::kDynamicSize) dynamicSizes.push_back(sizes[i]); } - Value mem = rewriter.create(loc, memTp, dynamicSizes); - Value zero = constantZero(rewriter, loc, elemTp); - rewriter.create(loc, ValueRange{zero}, ValueRange{mem}); + Value mem = builder.create(loc, memTp, dynamicSizes); + Value zero = constantZero(builder, loc, elemTp); + builder.create(loc, ValueRange{zero}, ValueRange{mem}); return mem; } /// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into /// the tensor created by allocDenseTensor(). The `rank` is the rank /// of the `tensor` and the length of `ind`. -static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter, - Location loc, Value elemPtr, - Value tensor, unsigned rank, - Value ind) { +static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc, + Value elemPtr, Value tensor, + unsigned rank, Value ind) { SmallVector ivs; ivs.reserve(rank); for (unsigned i = 0; i < rank; i++) { - Value idx = constantIndex(rewriter, loc, i); - ivs.push_back(rewriter.create(loc, ind, idx)); + Value idx = constantIndex(builder, loc, i); + ivs.push_back(builder.create(loc, ind, idx)); } - Value elemV = rewriter.create(loc, elemPtr); - rewriter.create(loc, elemV, tensor, ivs); + Value elemV = builder.create(loc, elemPtr); + builder.create(loc, elemV, tensor, ivs); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -400,7 +400,7 @@ /// given in Chapter 5 of "The Software Vectorization Handbook", where the /// initial scalar value is correctly embedded in the vector reduction value, /// and a straightforward horizontal reduction will complete the operation. -static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter, +static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder, Location loc, VectorType vtp) { Value r = codegen.redVal; switch (codegen.redKind) { @@ -409,27 +409,26 @@ case kSum: case kXor: // Initialize reduction vector to: | 0 | .. | 0 | r | - return rewriter.create( - loc, r, constantZero(rewriter, loc, vtp), - constantIndex(rewriter, loc, 0)); + return builder.create( + loc, r, constantZero(builder, loc, vtp), + constantIndex(builder, loc, 0)); case kProduct: // Initialize reduction vector to: | 1 | .. | 1 | r | - return rewriter.create( - loc, r, constantOne(rewriter, loc, vtp), - constantIndex(rewriter, loc, 0)); + return builder.create( + loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0)); case kAnd: case kOr: // Initialize reduction vector to: | r | .. | r | r | - return rewriter.create(loc, vtp, r); + return builder.create(loc, vtp, r); } llvm_unreachable("unknown reduction kind"); } /// Generates final value for a vector reduction. -static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter, +static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder, Location loc, VectorType vtp) { vector::CombiningKind kind = getCombiningKind(codegen.redKind); - return rewriter.create(loc, kind, codegen.redVal); + return builder.create(loc, kind, codegen.redVal); } /// Updates scalarized reduction value. @@ -448,7 +447,7 @@ /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)), /// only nonzeroes values are used for the updates and no assumption on the /// original contents of the output buffer is necessary.. -static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, +static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, MemRefType denseTp, ArrayRef args) { Location loc = op.getLoc(); @@ -458,21 +457,21 @@ // the major advantage that the sparse kernel only updates the nonzero // positions for the output tensor. if (isInPlace(tensor)) - return rewriter.create(loc, denseTp, tensor); + return builder.create(loc, denseTp, tensor); // By default, a new buffer is allocated which is initialized to the // tensor defined in the outs() clause. This is always correct but // introduces a dense initialization component that may negatively // impact the running complexity of the sparse kernel. If the tensor // materializes into the computation, we need to preserve the zero // initialization assumption of all sparse output buffers. - Value alloc = rewriter.create(loc, denseTp, args); + Value alloc = builder.create(loc, denseTp, args); if (isMaterializing(tensor)) { - Value zero = constantZero(rewriter, loc, denseTp.getElementType()); - rewriter.create(loc, ValueRange{zero}, ValueRange{alloc}); + Value zero = constantZero(builder, loc, denseTp.getElementType()); + builder.create(loc, ValueRange{zero}, ValueRange{alloc}); } else { Value init = - rewriter.create(loc, denseTp, tensor); - rewriter.create(loc, init, alloc); + builder.create(loc, denseTp, tensor); + builder.create(loc, init, alloc); } return alloc; } @@ -480,8 +479,8 @@ /// Local bufferization of all dense and sparse data structures. /// This code enables testing the first prototype sparse compiler. // TODO: replace this with a proliferated bufferization strategy -static void genBuffers(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op) { +static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op) { Location loc = op.getLoc(); assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); // For every tensor, find lower and upper bound on dimensions, set the @@ -503,19 +502,19 @@ if (merger.isDim(tensor, idx, Dim::kSparse)) { auto dynShape = {ShapedType::kDynamicSize}; auto ptrTp = - MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc)); + MemRefType::get(dynShape, getPointerOverheadType(builder, enc)); auto indTp = - MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc)); - Value dim = constantIndex(rewriter, loc, d); + MemRefType::get(dynShape, getIndexOverheadType(builder, enc)); + Value dim = constantIndex(builder, loc, d); // Generate sparse primitives to obtains pointer and indices. codegen.pointers[tensor][idx] = - rewriter.create(loc, ptrTp, t->get(), dim); + builder.create(loc, ptrTp, t->get(), dim); codegen.indices[tensor][idx] = - rewriter.create(loc, indTp, t->get(), dim); + builder.create(loc, indTp, t->get(), dim); } // Find upper bound in current dimension. unsigned p = perm(enc, d); - Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p); + Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p); if (ShapedType::isDynamic(shape[p])) args.push_back(up); assert(codegen.highs[tensor][idx] == nullptr); @@ -531,22 +530,22 @@ auto denseTp = MemRefType::get(shape, elementType); if (tensor < op.getNumInputs()) codegen.buffers[tensor] = - rewriter.create(loc, denseTp, t->get()); + builder.create(loc, denseTp, t->get()); else codegen.buffers[tensor] = - genOutputBuffer(codegen, rewriter, op, denseTp, args); + genOutputBuffer(codegen, builder, op, denseTp, args); } else if (t == codegen.sparseOut) { // True sparse output needs a lexIdx array. - Value rank = constantIndex(rewriter, loc, op.getRank(t)); + Value rank = constantIndex(builder, loc, op.getRank(t)); auto dynShape = {ShapedType::kDynamicSize}; - auto memTp = MemRefType::get(dynShape, rewriter.getIndexType()); - codegen.lexIdx = rewriter.create(loc, memTp, rank); + auto memTp = MemRefType::get(dynShape, builder.getIndexType()); + codegen.lexIdx = builder.create(loc, memTp, rank); } else { // Annotated sparse tensors. auto dynShape = {ShapedType::kDynamicSize}; auto sparseTp = MemRefType::get(dynShape, elementType); codegen.buffers[tensor] = - rewriter.create(loc, sparseTp, t->get()); + builder.create(loc, sparseTp, t->get()); } } } @@ -563,10 +562,10 @@ } /// Constructs vector iteration mask. -static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, - Value iv, Value lo, Value hi, Value step) { +static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv, + Value lo, Value hi, Value step) { Location loc = iv.getLoc(); - VectorType mtp = vectorType(codegen, rewriter.getI1Type()); + VectorType mtp = vectorType(codegen, builder.getI1Type()); // Special case if the vector length evenly divides the trip count (for // example, "for i = 0, 128, 16"). A constant all-true mask is generated // so that all subsequent masked memory operations are immediately folded @@ -576,8 +575,8 @@ matchPattern(hi, m_Constant(&hiInt)) && matchPattern(step, m_Constant(&stepInt))) { if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) - return rewriter.create( - loc, mtp, constantI1(rewriter, loc, true)); + return builder.create( + loc, mtp, constantI1(builder, loc, true)); } // Otherwise, generate a vector mask that avoids overrunning the upperbound // during vector execution. Here we rely on subsequent loop optimizations to @@ -585,61 +584,61 @@ // loop into an unconditional vector loop and a scalar cleanup loop. auto minMap = AffineMap::get( /*dimCount=*/2, /*symbolCount=*/1, - {rewriter.getAffineSymbolExpr(0), - rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)}, - rewriter.getContext()); + {builder.getAffineSymbolExpr(0), + builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)}, + builder.getContext()); Value end = - rewriter.createOrFold(loc, minMap, ValueRange{hi, iv, step}); - return rewriter.create(loc, mtp, end); + builder.createOrFold(loc, minMap, ValueRange{hi, iv, step}); + return builder.create(loc, mtp, end); } /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. -static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, - Value ptr, ArrayRef args) { +static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr, + ArrayRef args) { Location loc = ptr.getLoc(); VectorType vtp = vectorType(codegen, ptr); - Value pass = constantZero(rewriter, loc, vtp); + Value pass = constantZero(builder, loc, vtp); if (args.back().getType().isa()) { SmallVector scalarArgs(args.begin(), args.end()); Value indexVec = args.back(); - scalarArgs.back() = constantIndex(rewriter, loc, 0); - return rewriter.create( - loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); + scalarArgs.back() = constantIndex(builder, loc, 0); + return builder.create(loc, vtp, ptr, scalarArgs, indexVec, + codegen.curVecMask, pass); } - return rewriter.create(loc, vtp, ptr, args, - codegen.curVecMask, pass); + return builder.create(loc, vtp, ptr, args, + codegen.curVecMask, pass); } /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. -static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, - Value rhs, Value ptr, ArrayRef args) { +static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs, + Value ptr, ArrayRef args) { Location loc = ptr.getLoc(); if (args.back().getType().isa()) { SmallVector scalarArgs(args.begin(), args.end()); Value indexVec = args.back(); - scalarArgs.back() = constantIndex(rewriter, loc, 0); - rewriter.create(loc, ptr, scalarArgs, indexVec, - codegen.curVecMask, rhs); + scalarArgs.back() = constantIndex(builder, loc, 0); + builder.create(loc, ptr, scalarArgs, indexVec, + codegen.curVecMask, rhs); return; } - rewriter.create(loc, ptr, args, codegen.curVecMask, - rhs); + builder.create(loc, ptr, args, codegen.curVecMask, + rhs); } /// Generates a vectorized invariant. Here we rely on subsequent loop /// optimizations to hoist the invariant broadcast out of the vector loop. -static Value genVectorInvariantValue(CodeGen &codegen, - PatternRewriter &rewriter, Value val) { +static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder, + Value val) { VectorType vtp = vectorType(codegen, val.getType()); - return rewriter.create(val.getLoc(), vtp, val); + return builder.create(val.getLoc(), vtp, val); } /// Generates an affine expression. // // TODO: generalize for sparse tensor subscripts // -static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter, - AffineExpr a, Location loc) { +static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a, + Location loc) { switch (a.getKind()) { case AffineExprKind::DimId: { unsigned idx = a.cast().getPosition(); @@ -647,19 +646,19 @@ } case AffineExprKind::Add: { auto binOp = a.cast(); - return rewriter.create( - loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), - genAffine(codegen, rewriter, binOp.getRHS(), loc)); + return builder.create( + loc, genAffine(codegen, builder, binOp.getLHS(), loc), + genAffine(codegen, builder, binOp.getRHS(), loc)); } case AffineExprKind::Mul: { auto binOp = a.cast(); - return rewriter.create( - loc, genAffine(codegen, rewriter, binOp.getLHS(), loc), - genAffine(codegen, rewriter, binOp.getRHS(), loc)); + return builder.create( + loc, genAffine(codegen, builder, binOp.getLHS(), loc), + genAffine(codegen, builder, binOp.getRHS(), loc)); } case AffineExprKind::Constant: { int64_t c = a.cast().getValue(); - return constantIndex(rewriter, loc, c); + return constantIndex(builder, loc, c); } default: llvm_unreachable("unexpected affine subscript"); @@ -677,7 +676,7 @@ } /// Generates subscript for load/store on a dense or sparse tensor. -static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter, +static Value genSubscript(CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, OpOperand *t, SmallVector &args) { unsigned tensor = t->getOperandNumber(); @@ -695,33 +694,33 @@ } else { for (unsigned d = 0; d < rank; d++) { AffineExpr a = map.getResult(perm(enc, d)); - args.push_back(genAffine(codegen, rewriter, a, op.getLoc())); + args.push_back(genAffine(codegen, builder, a, op.getLoc())); } } return codegen.buffers[tensor]; } /// Generates insertion code to implement dynamic tensor load. -static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter, +static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, OpOperand *t) { Location loc = op.getLoc(); // Direct lexicographic index order, tensor loads as zero. if (!codegen.expValues) { Type tp = getElementTypeOrSelf(t->get().getType()); - return constantZero(rewriter, loc, tp); + return constantZero(builder, loc, tp); } // Load from expanded access pattern. Value index = genIndex(codegen, op, t); - return rewriter.create(loc, codegen.expValues, index); + return builder.create(loc, codegen.expValues, index); } /// Generates insertion code to implement dynamic tensor store. -static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter, +static void genInsertionStore(CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, OpOperand *t, Value rhs) { Location loc = op.getLoc(); // Direct insertion in lexicographic index order. if (!codegen.expValues) { - rewriter.create(loc, t->get(), codegen.lexIdx, rhs); + builder.create(loc, t->get(), codegen.lexIdx, rhs); return; } // Generates insertion code along expanded access pattern. @@ -731,64 +730,62 @@ // endif // values[i] = rhs Value index = genIndex(codegen, op, t); - Value fval = constantI1(rewriter, loc, false); - Value tval = constantI1(rewriter, loc, true); + Value fval = constantI1(builder, loc, false); + Value tval = constantI1(builder, loc, true); // If statement. - Value filled = rewriter.create(loc, codegen.expFilled, index); - Value cond = rewriter.create(loc, arith::CmpIPredicate::eq, - filled, fval); - scf::IfOp ifOp = rewriter.create(loc, rewriter.getIndexType(), - cond, /*else=*/true); + Value filled = builder.create(loc, codegen.expFilled, index); + Value cond = builder.create(loc, arith::CmpIPredicate::eq, + filled, fval); + scf::IfOp ifOp = builder.create(loc, builder.getIndexType(), cond, + /*else=*/true); // True branch. - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - rewriter.create(loc, tval, codegen.expFilled, index); - rewriter.create(loc, index, codegen.expAdded, - codegen.expCount); - Value one = constantIndex(rewriter, loc, 1); - Value add = rewriter.create(loc, codegen.expCount, one); - rewriter.create(loc, add); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + builder.create(loc, tval, codegen.expFilled, index); + builder.create(loc, index, codegen.expAdded, + codegen.expCount); + Value one = constantIndex(builder, loc, 1); + Value add = builder.create(loc, codegen.expCount, one); + builder.create(loc, add); // False branch. - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - rewriter.create(loc, codegen.expCount); - rewriter.setInsertionPointAfter(ifOp); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(loc, codegen.expCount); + builder.setInsertionPointAfter(ifOp); // Value assignment. codegen.expCount = ifOp.getResult(0); - rewriter.create(loc, rhs, codegen.expValues, index); + builder.create(loc, rhs, codegen.expValues, index); } /// Generates a load on a dense or sparse tensor. -static Value genTensorLoad(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned exp) { +static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned exp) { // Test if the load was hoisted to a higher loop nest. Value val = merger.exp(exp).val; if (val) { if (codegen.curVecLength > 1 && !val.getType().isa()) - return genVectorInvariantValue(codegen, rewriter, val); + return genVectorInvariantValue(codegen, builder, val); return val; } // Load during insertion. OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; if (t == codegen.sparseOut) - return genInsertionLoad(codegen, rewriter, op, t); + return genInsertionLoad(codegen, builder, op, t); // Actual load. SmallVector args; - Value ptr = genSubscript(codegen, rewriter, op, t, args); + Value ptr = genSubscript(codegen, builder, op, t, args); if (codegen.curVecLength > 1) - return genVectorLoad(codegen, rewriter, ptr, args); - return rewriter.create(op.getLoc(), ptr, args); + return genVectorLoad(codegen, builder, ptr, args); + return builder.create(op.getLoc(), ptr, args); } /// Generates a store on a dense or sparse tensor. -static void genTensorStore(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned exp, Value rhs) { +static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned exp, Value rhs) { Location loc = op.getLoc(); // Test if this is a scalarized reduction. if (codegen.redVal) { if (codegen.curVecLength > 1) - rhs = rewriter.create(loc, codegen.curVecMask, rhs, - codegen.redVal); + rhs = builder.create(loc, codegen.curVecMask, rhs, + codegen.redVal); updateReduc(merger, codegen, rhs); return; } @@ -800,23 +797,23 @@ // to indicate missing output. assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary); } else { - genInsertionStore(codegen, rewriter, op, t, rhs); + genInsertionStore(codegen, builder, op, t, rhs); } return; } // Actual store. SmallVector args; - Value ptr = genSubscript(codegen, rewriter, op, t, args); + Value ptr = genSubscript(codegen, builder, op, t, args); if (codegen.curVecLength > 1) - genVectorStore(codegen, rewriter, rhs, ptr, args); + genVectorStore(codegen, builder, rhs, ptr, args); else - rewriter.create(loc, rhs, ptr, args); + builder.create(loc, rhs, ptr, args); } /// Generates a pointer/index load from the sparse storage scheme. Narrower /// data types need to be zero extended before casting the value into the /// index type used for looping and indexing. -static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc, +static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc, Value ptr, Value s) { // See https://llvm.org/docs/GetElementPtr.html for some background on // the complications described below. @@ -833,15 +830,15 @@ // incorrect address calculations in the unlikely case we need such // extremely large offsets. Type etp = ptr.getType().cast().getElementType(); - Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); + Value vload = genVectorLoad(codegen, builder, ptr, {s}); if (!etp.isa()) { if (etp.getIntOrFloatBitWidth() < 32) - vload = rewriter.create( - loc, vectorType(codegen, rewriter.getI32Type()), vload); + vload = builder.create( + loc, vectorType(codegen, builder.getI32Type()), vload); else if (etp.getIntOrFloatBitWidth() < 64 && !codegen.options.enableSIMDIndex32) - vload = rewriter.create( - loc, vectorType(codegen, rewriter.getI64Type()), vload); + vload = builder.create( + loc, vectorType(codegen, builder.getI64Type()), vload); } return vload; } @@ -849,41 +846,40 @@ // values before casting to index without a performance penalty. Here too, // however, indices that already are 64-bit, in theory, cannot express the // full range as explained above. - Value load = rewriter.create(loc, ptr, s); + Value load = builder.create(loc, ptr, s); if (!load.getType().isa()) { if (load.getType().getIntOrFloatBitWidth() < 64) - load = rewriter.create(loc, rewriter.getI64Type(), load); + load = builder.create(loc, builder.getI64Type(), load); load = - rewriter.create(loc, rewriter.getIndexType(), load); + builder.create(loc, builder.getIndexType(), load); } return load; } /// Generates an invariant value. static Value genInvariantValue(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, unsigned exp) { + OpBuilder &builder, unsigned exp) { Value val = merger.exp(exp).val; if (codegen.curVecLength > 1) - return genVectorInvariantValue(codegen, rewriter, val); + return genVectorInvariantValue(codegen, builder, val); return val; } /// Generates an address computation "sz * p + i". -static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, - Location loc, Value size, Value p, Value i) { - Value mul = rewriter.create(loc, size, p); +static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc, + Value size, Value p, Value i) { + Value mul = builder.create(loc, size, p); if (auto vtp = i.getType().dyn_cast()) { Value inv = - rewriter.create(loc, vtp.getElementType(), mul); - mul = genVectorInvariantValue(codegen, rewriter, inv); + builder.create(loc, vtp.getElementType(), mul); + mul = genVectorInvariantValue(codegen, builder, inv); } - return rewriter.create(loc, mul, i); + return builder.create(loc, mul, i); } /// Generates an index value. -static Value genIndexValue(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, unsigned exp, - unsigned ldx) { +static Value genIndexValue(Merger &merger, CodeGen &codegen, OpBuilder &builder, + unsigned exp, unsigned ldx) { unsigned idx = merger.exp(exp).index; Value ival = codegen.loops[idx]; Type itype = ival.getType(); @@ -894,28 +890,28 @@ if (vl > 1 && !itype.isa()) { Location loc = ival.getLoc(); VectorType vtp = vectorType(codegen, itype); - ival = rewriter.create(loc, vtp, ival); + ival = builder.create(loc, vtp, ival); if (idx == ldx) { Value incr; if (vtp.isScalable()) { - Type stepvty = vectorType(codegen, rewriter.getI64Type()); - Value stepv = rewriter.create(loc, stepvty); - incr = rewriter.create(loc, vtp, stepv); + Type stepvty = vectorType(codegen, builder.getI64Type()); + Value stepv = builder.create(loc, stepvty); + incr = builder.create(loc, vtp, stepv); } else { SmallVector integers; for (unsigned i = 0; i < vl; i++) integers.push_back(APInt(/*width=*/64, i)); auto values = DenseElementsAttr::get(vtp, integers); - incr = rewriter.create(loc, vtp, values); + incr = builder.create(loc, vtp, values); } - ival = rewriter.create(loc, ival, incr); + ival = builder.create(loc, ival, incr); } } return ival; } /// Recursively generates tensor expression. -static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, +static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, linalg::GenericOp op, unsigned exp, unsigned ldx) { Location loc = op.getLoc(); if (exp == -1u) @@ -955,10 +951,9 @@ } /// Hoists loop invariant tensor loads for which indices have been exhausted. -static void genInvariants(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned exp, unsigned ldx, bool atStart, - Kind last = Kind::kTensor) { +static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned exp, unsigned ldx, + bool atStart, Kind last = Kind::kTensor) { if (exp == -1u) return; if (merger.exp(exp).kind == Kind::kTensor) { @@ -979,7 +974,7 @@ if (lhs == t) { // Start or end a scalarized reduction if (atStart) { - Value load = genTensorLoad(merger, codegen, rewriter, op, exp); + Value load = genTensorLoad(merger, codegen, builder, op, exp); codegen.redKind = getReduction(last); codegen.redExp = exp; updateReduc(merger, codegen, load); @@ -988,12 +983,12 @@ updateReduc(merger, codegen, Value()); codegen.redExp = -1u; codegen.redKind = kNoReduc; - genTensorStore(merger, codegen, rewriter, op, exp, redVal); + genTensorStore(merger, codegen, builder, op, exp, redVal); } } else { // Start or end loop invariant hoisting of a tensor load. merger.exp(exp).val = - atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); + atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value(); } } else if (merger.exp(exp).kind != Kind::kInvariant && merger.exp(exp).kind != Kind::kIndex) { @@ -1003,15 +998,14 @@ Kind last = merger.exp(exp).kind; unsigned e0 = merger.exp(exp).children.e0; unsigned e1 = merger.exp(exp).children.e1; - genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last); - genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last); + genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last); + genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last); } } /// Generates an expanded access pattern in innermost dimension. -static void genExpansion(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned at, bool atStart) { +static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned at, bool atStart) { OpOperand *lhs = codegen.sparseOut; if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || at != codegen.outerParNest) @@ -1023,11 +1017,11 @@ auto dynShape = {ShapedType::kDynamicSize}; Type etp = tensor.getType().cast().getElementType(); Type t1 = MemRefType::get(dynShape, etp); - Type t2 = MemRefType::get(dynShape, rewriter.getI1Type()); - Type t3 = MemRefType::get(dynShape, rewriter.getIndexType()); - Type t4 = rewriter.getIndexType(); + Type t2 = MemRefType::get(dynShape, builder.getI1Type()); + Type t3 = MemRefType::get(dynShape, builder.getIndexType()); + Type t4 = builder.getIndexType(); auto res = - rewriter.create(loc, TypeRange({t1, t2, t3, t4}), tensor); + builder.create(loc, TypeRange({t1, t2, t3, t4}), tensor); assert(res.getNumResults() == 4); assert(!codegen.expValues); codegen.expValues = res.getResult(0); @@ -1036,9 +1030,9 @@ codegen.expCount = res.getResult(3); } else { assert(codegen.expValues); - rewriter.create(loc, tensor, codegen.lexIdx, codegen.expValues, - codegen.expFilled, codegen.expAdded, - codegen.expCount); + builder.create(loc, tensor, codegen.lexIdx, codegen.expValues, + codegen.expFilled, codegen.expAdded, + codegen.expCount); codegen.expValues = codegen.expFilled = codegen.expAdded = codegen.expCount = Value(); } @@ -1047,7 +1041,7 @@ /// Generates initialization code for the subsequent loop sequence at /// current index level. Returns true if the loop sequence needs to /// maintain the universal index. -static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, +static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, std::vector &topSort, unsigned at, BitVector &inits) { bool needsUniv = false; @@ -1067,12 +1061,12 @@ break; } Value ptr = codegen.pointers[tensor][idx]; - Value one = constantIndex(rewriter, loc, 1); - Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0) + Value one = constantIndex(builder, loc, 1); + Value p0 = (pat == 0) ? constantIndex(builder, loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; - codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); - Value p1 = rewriter.create(loc, p0, one); - codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); + codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0); + Value p1 = builder.create(loc, p0, one); + codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1); } else { // Dense index still in play. needsUniv = true; @@ -1081,7 +1075,7 @@ } // Initialize the universal dense index. - codegen.loops[idx] = constantIndex(rewriter, loc, 0); + codegen.loops[idx] = constantIndex(builder, loc, 0); return needsUniv; } @@ -1155,10 +1149,9 @@ } /// Generates a for-loop on a single index. -static Operation *genFor(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - bool isOuter, bool isInner, unsigned idx, - BitVector &indices) { +static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, bool isOuter, bool isInner, + unsigned idx, BitVector &indices) { unsigned fb = indices.find_first(); unsigned tensor = merger.tensor(fb); assert(idx == merger.index(fb)); @@ -1178,22 +1171,22 @@ Location loc = op.getLoc(); Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; - Value step = constantIndex(rewriter, loc, codegen.curVecLength); + Value step = constantIndex(builder, loc, codegen.curVecLength); if (isVector && codegen.options.enableVLAVectorization) { - Value vscale = rewriter.create( - loc, IndexType::get(rewriter.getContext())); - step = rewriter.create(loc, vscale, step); + Value vscale = builder.create( + loc, IndexType::get(builder.getContext())); + step = builder.create(loc, vscale, step); } // Emit a parallel loop. if (isParallel) { assert(!isVector); - scf::ParallelOp parOp = rewriter.create(loc, lo, hi, step); + scf::ParallelOp parOp = builder.create(loc, lo, hi, step); if (isSparse) codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; else codegen.loops[idx] = parOp.getInductionVars()[0]; - rewriter.setInsertionPointToStart(parOp.getBody()); + builder.setInsertionPointToStart(parOp.getBody()); return parOp; } @@ -1203,14 +1196,14 @@ // In a vector loop, bring reduction into SIMD form, if not already. if (isVector && !codegen.redVal.getType().isa()) { VectorType vtp = vectorType(codegen, codegen.redVal.getType()); - Value vred = genVectorReducInit(codegen, rewriter, loc, vtp); + Value vred = genVectorReducInit(codegen, builder, loc, vtp); updateReduc(merger, codegen, vred); } operands.push_back(codegen.redVal); } if (codegen.expValues) operands.push_back(codegen.expCount); - scf::ForOp forOp = rewriter.create(loc, lo, hi, step, operands); + scf::ForOp forOp = builder.create(loc, lo, hi, step, operands); if (codegen.redVal) updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); if (codegen.expValues) @@ -1221,21 +1214,21 @@ codegen.pidxs[tensor][idx] = iv; else codegen.loops[idx] = iv; - rewriter.setInsertionPointToStart(forOp.getBody()); + builder.setInsertionPointToStart(forOp.getBody()); // Share vector iteration mask between all subsequent loads/stores. if (isVector) - codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); + codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step); return forOp; } /// Emit a while-loop for co-iteration over multiple indices. -static Operation *genWhile(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned idx, bool needsUniv, BitVector &indices) { +static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned idx, bool needsUniv, + BitVector &indices) { SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each index. - Type indexType = rewriter.getIndexType(); + Type indexType = builder.getIndexType(); for (unsigned b = 0, be = indices.size(); b < be; b++) { if (indices[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); @@ -1258,15 +1251,15 @@ } assert(types.size() == operands.size()); Location loc = op.getLoc(); - scf::WhileOp whileOp = rewriter.create(loc, types, operands); + scf::WhileOp whileOp = builder.create(loc, types, operands); SmallVector locs(types.size(), loc); - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs); - Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs); + Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); + Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); // Build the "before" region, which effectively consists // of a conjunction of "i < upper" tests on all induction. - rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); + builder.setInsertionPointToStart(&whileOp.getBefore().front()); Value cond; unsigned o = 0; for (unsigned b = 0, be = indices.size(); b < be; b++) { @@ -1275,9 +1268,9 @@ assert(idx == merger.index(b)); Value op1 = before->getArgument(o); Value op2 = codegen.highs[tensor][idx]; - Value opc = rewriter.create(loc, arith::CmpIPredicate::ult, - op1, op2); - cond = cond ? rewriter.create(loc, cond, opc) : opc; + Value opc = builder.create(loc, arith::CmpIPredicate::ult, + op1, op2); + cond = cond ? builder.create(loc, cond, opc) : opc; codegen.pidxs[tensor][idx] = after->getArgument(o++); } } @@ -1288,33 +1281,30 @@ if (needsUniv) codegen.loops[idx] = after->getArgument(o++); assert(o == operands.size()); - rewriter.create(loc, cond, before->getArguments()); - rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); + builder.create(loc, cond, before->getArguments()); + builder.setInsertionPointToStart(&whileOp.getAfter().front()); return whileOp; } /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction. -static Operation *genLoop(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - std::vector &topSort, unsigned at, - bool needsUniv, BitVector &indices) { +static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, std::vector &topSort, + unsigned at, bool needsUniv, BitVector &indices) { unsigned idx = topSort[at]; if (indices.count() == 1) { bool isOuter = at == 0; bool isInner = at == topSort.size() - 1; - return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, - indices); + return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices); } - return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); + return genWhile(merger, codegen, builder, op, idx, needsUniv, indices); } /// Generates the local variables for this loop, consisting of the sparse /// indices, restored universal dense index, and dense positions. -static void genLocals(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - std::vector &topSort, unsigned at, - bool needsUniv, BitVector &locals) { +static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, std::vector &topSort, + unsigned at, bool needsUniv, BitVector &locals) { Location loc = op.getLoc(); unsigned idx = topSort[at]; @@ -1326,13 +1316,13 @@ assert(idx == merger.index(b)); Value ptr = codegen.indices[tensor][idx]; Value s = codegen.pidxs[tensor][idx]; - Value load = genLoad(codegen, rewriter, loc, ptr, s); + Value load = genLoad(codegen, builder, loc, ptr, s); codegen.idxs[tensor][idx] = load; if (!needsUniv) { if (min) { - Value cmp = rewriter.create( + Value cmp = builder.create( loc, arith::CmpIPredicate::ult, load, min); - min = rewriter.create(loc, cmp, load, min); + min = builder.create(loc, cmp, load, min); } else { min = load; } @@ -1358,32 +1348,32 @@ for (; pat != 0; pat--) if (codegen.pidxs[tensor][topSort[pat - 1]]) break; - Value p = (pat == 0) ? constantIndex(rewriter, loc, 0) + Value p = (pat == 0) ? constantIndex(builder, loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; codegen.pidxs[tensor][idx] = genAddress( - codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); + codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]); } } // Move the insertion indices in lexicographic index order. During access // pattern expansion, we can skip setting the innermost dimension. if (codegen.sparseOut && !codegen.expValues) { - Value pos = constantIndex(rewriter, loc, at); - rewriter.create(loc, codegen.loops[idx], codegen.lexIdx, - pos); + Value pos = constantIndex(builder, loc, at); + builder.create(loc, codegen.loops[idx], codegen.lexIdx, + pos); } } /// Generates the induction structure for a while-loop. static void genWhileInduction(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, + OpBuilder &builder, linalg::GenericOp op, unsigned idx, bool needsUniv, BitVector &induction, scf::WhileOp whileOp) { Location loc = op.getLoc(); // Finalize each else branch of all if statements. if (codegen.redVal || codegen.expValues) { while (auto ifOp = dyn_cast_or_null( - rewriter.getInsertionBlock()->getParentOp())) { + builder.getInsertionBlock()->getParentOp())) { unsigned y = 0; SmallVector yields; if (codegen.redVal) { @@ -1395,11 +1385,11 @@ codegen.expCount = ifOp->getResult(y++); } assert(y == yields.size()); - rewriter.create(loc, yields); - rewriter.setInsertionPointAfter(ifOp); + builder.create(loc, yields); + builder.setInsertionPointAfter(ifOp); } } - rewriter.setInsertionPointToEnd(&whileOp.getAfter().front()); + builder.setInsertionPointToEnd(&whileOp.getAfter().front()); // Finalize the induction. Note that the induction could be performed // in the individual if-branches to avoid re-evaluating the conditions. // However, that would result in a rather elaborate forest of yield @@ -1407,7 +1397,7 @@ // after the if-statements more closely resembles code generated by TACO. unsigned o = 0; SmallVector operands; - Value one = constantIndex(rewriter, loc, 1); + Value one = constantIndex(builder, loc, 1); for (unsigned b = 0, be = induction.size(); b < be; b++) { if (induction[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); @@ -1415,10 +1405,10 @@ Value op1 = codegen.idxs[tensor][idx]; Value op2 = codegen.loops[idx]; Value op3 = codegen.pidxs[tensor][idx]; - Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, - op1, op2); - Value add = rewriter.create(loc, op3, one); - operands.push_back(rewriter.create(loc, cmp, add, op3)); + Value cmp = builder.create(loc, arith::CmpIPredicate::eq, + op1, op2); + Value add = builder.create(loc, op3, one); + operands.push_back(builder.create(loc, cmp, add, op3)); codegen.pidxs[tensor][idx] = whileOp->getResult(o++); } } @@ -1432,17 +1422,17 @@ } if (needsUniv) { operands.push_back( - rewriter.create(loc, codegen.loops[idx], one)); + builder.create(loc, codegen.loops[idx], one)); codegen.loops[idx] = whileOp->getResult(o++); } assert(o == operands.size()); - rewriter.create(loc, operands); - rewriter.setInsertionPointAfter(whileOp); + builder.create(loc, operands); + builder.setInsertionPointAfter(whileOp); } /// Generates the induction structure for a for-loop. static void genForInduction(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, + OpBuilder &builder, linalg::GenericOp op, Operation *loop) { Location loc = op.getLoc(); unsigned o = 0; @@ -1457,14 +1447,14 @@ } assert(o == operands.size()); if (o > 0) - rewriter.create(loc, operands); - rewriter.setInsertionPointAfter(loop); + builder.create(loc, operands); + builder.setInsertionPointAfter(loop); } /// Generates a single if-statement within a while-loop. -static scf::IfOp genIf(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned idx, BitVector &conditions) { +static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned idx, + BitVector &conditions) { Location loc = op.getLoc(); SmallVector types; Value cond; @@ -1476,25 +1466,25 @@ if (merger.isDim(b, Dim::kSparse)) { Value op1 = codegen.idxs[tensor][idx]; Value op2 = codegen.loops[idx]; - clause = rewriter.create(loc, arith::CmpIPredicate::eq, - op1, op2); + clause = builder.create(loc, arith::CmpIPredicate::eq, + op1, op2); } else { - clause = constantI1(rewriter, loc, true); + clause = constantI1(builder, loc, true); } - cond = cond ? rewriter.create(loc, cond, clause) : clause; + cond = cond ? builder.create(loc, cond, clause) : clause; } } if (codegen.redVal) types.push_back(codegen.redVal.getType()); if (codegen.expValues) - types.push_back(rewriter.getIndexType()); - scf::IfOp ifOp = rewriter.create(loc, types, cond, /*else=*/true); - rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + types.push_back(builder.getIndexType()); + scf::IfOp ifOp = builder.create(loc, types, cond, /*else=*/true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); return ifOp; } /// Generates end of true branch of if-statement within a while-loop. -static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, +static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, scf::IfOp ifOp, Operation *loop, Value redInput, Value cntInput) { SmallVector operands; @@ -1507,8 +1497,8 @@ codegen.expCount = cntInput; } if (!operands.empty()) - rewriter.create(op.getLoc(), operands); - rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(op.getLoc(), operands); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); } //===----------------------------------------------------------------------===// @@ -1517,21 +1507,20 @@ /// Starts a loop sequence at given level. Returns true if /// the universal loop index must be maintained at this level. -static bool startLoopSeq(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - std::vector &topSort, unsigned exp, - unsigned at, unsigned idx, unsigned ldx, +static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, std::vector &topSort, + unsigned exp, unsigned at, unsigned idx, unsigned ldx, unsigned lts) { assert(codegen.curVecLength == 1); assert(!codegen.loops[idx]); // Emit invariants at this loop sequence level. - genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true); + genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true); // Emit access pattern expansion for sparse tensor output. - genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true); + genExpansion(merger, codegen, builder, op, at, /*atStart=*/true); // Emit further intitialization at this loop sequence level. unsigned l0 = merger.set(lts)[0]; bool needsUniv = - genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits); + genInit(merger, codegen, builder, op, topSort, at, merger.lat(l0).bits); // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. if (needsUniv) { @@ -1547,56 +1536,56 @@ /// Starts a single loop in current sequence. static Operation *startLoop(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, + OpBuilder &builder, linalg::GenericOp op, std::vector &topSort, unsigned at, unsigned li, bool needsUniv) { assert(codegen.curVecLength == 1); // Emit the for/while-loop control. - Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at, + Operation *loop = genLoop(merger, codegen, builder, op, topSort, at, needsUniv, merger.lat(li).simple); // Emit the locals for this loop. - genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, + genLocals(merger, codegen, builder, op, topSort, at, needsUniv, merger.lat(li).bits); return loop; } /// Ends a single loop in current sequence. Returns new values for needsUniv. -static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, +static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, Operation *loop, unsigned idx, unsigned li, bool needsUniv) { codegen.curVecLength = 1; // End a while-loop. if (auto whileOp = dyn_cast(loop)) { - genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, + genWhileInduction(merger, codegen, builder, op, idx, needsUniv, merger.lat(li).bits, whileOp); return needsUniv; } // End a for-loop. - genForInduction(merger, codegen, rewriter, op, loop); + genForInduction(merger, codegen, builder, op, loop); return false; } /// Ends a loop sequence at given level. -static void endLoopSeq(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op, - unsigned exp, unsigned at, unsigned idx, unsigned ldx) { +static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, + linalg::GenericOp op, unsigned exp, unsigned at, + unsigned idx, unsigned ldx) { assert(codegen.curVecLength == 1); codegen.loops[idx] = Value(); // Bring a pending reduction back from SIMD form when sequence ends. if (codegen.redVal) if (auto vtp = codegen.redVal.getType().dyn_cast()) updateReduc(merger, codegen, - genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp)); + genVectorReducEnd(codegen, builder, op.getLoc(), vtp)); // Unmark bookkeeping of invariants and loop index. - genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false); + genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false); // Finalize access pattern expansion for sparse tensor output. - genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false); + genExpansion(merger, codegen, builder, op, at, /*atStart=*/false); } /// Recursively generates code while computing iteration lattices in order /// to manage the complexity of implementing co-iteration over unions /// and intersections of sparse iterations spaces. -static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, +static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, linalg::GenericOp op, std::vector &topSort, unsigned exp, unsigned at) { // At each leaf, assign remaining tensor (sub)expression to output tensor. @@ -1655,8 +1644,8 @@ } /// Converts the result computed by the sparse kernel into the required form. -static void genResult(Merger &merger, CodeGen &codegen, - PatternRewriter &rewriter, linalg::GenericOp op) { +static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, + linalg::GenericOp op) { OpOperand *lhs = op.getOutputOperand(0); Type resType = lhs->get().getType(); if (getSparseTensorEncoding(resType)) { diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -825,8 +825,8 @@ return None; } -static Value insertYieldOp(PatternRewriter &rewriter, Location loc, - Region ®ion, ValueRange vals) { +static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, + ValueRange vals) { // Make a clone of overlap region. Region tmpRegion; BlockAndValueMapping mapper; @@ -842,7 +842,7 @@ return val; } -static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc, +static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, Operation *op, Value v0) { if (!v0) // Empty input value must be propagated. @@ -856,7 +856,7 @@ return insertYieldOp(rewriter, loc, presentRegion, {v0}); } -static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc, +static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Operation *op, Value v0, Value v1) { if (!v0 || !v1) // Empty input values must be propagated. @@ -870,7 +870,7 @@ return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); } -Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, +Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, Value v1) { switch (tensorExps[e].kind) { case kTensor: