diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,7 +23,6 @@
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -239,17 +239,7 @@
   let summary = "Convert operations from the shape dialect into the standard "
                 "dialect";
   let constructor = "mlir::createConvertShapeToStandardPass()";
-  let dependentDialects = ["StandardOpsDialect"];
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeToSCF
-//===----------------------------------------------------------------------===//
-
-def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
-  let summary = "Convert operations from the shape dialect to the SCF dialect";
-  let constructor = "mlir::createConvertShapeToSCFPass()";
-  let dependentDialects = ["scf::SCFDialect"];
+  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
deleted file mode 100644
--- a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ---------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
-#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
-
-#include <memory>
-
-namespace mlir {
-
-class MLIRContext;
-class FunctionPass;
-class OwningRewritePatternList;
-
-void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
-                                          MLIRContext *ctx);
-
-std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,7 +12,6 @@
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToSPIRV)
 add_subdirectory(SCFToStandard)
-add_subdirectory(ShapeToSCF)
 add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
diff --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
deleted file mode 100644
--- a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRShapeToSCF
-  ShapeToSCF.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRShape
-  MLIRPass
-  MLIRSCF
-  MLIRTransforms
-  )
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
deleted file mode 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ /dev/null
@@ -1,337 +0,0 @@
-//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
-
-#include "../PassDetail.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::shape;
-using namespace mlir::scf;
-
-namespace {
-struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
-  using OpConversionPattern<BroadcastOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult BroadcastOpConverter::matchAndRewrite(
-    BroadcastOp op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
-  // on shapes.
-  if (op.getType().isa<ShapeType>())
-    return failure();
-
-  assert(!op.lhs().getType().isa<ShapeType>() &&
-         !op.rhs().getType().isa<ShapeType>());
-  auto loc = op.getLoc();
-  BroadcastOp::Adaptor transformed(operands);
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsSmaller =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Type extentTensorTy = op.getType();
-  auto ifOp = rewriter.create<IfOp>(
-      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
-      lhsSmaller,
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
-                                               rhsRank, transformed.rhs()});
-      },
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
-                                               lhsRank, transformed.lhs()});
-      });
-  Value smallerRank = ifOp.getResult(0);
-  Value smallerOperand = ifOp.getResult(1);
-  Value greaterRank = ifOp.getResult(2);
-  Value greaterOperand = ifOp.getResult(3);
-
-  // Allocate stack memory for the broadcasted extent tensor.
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
-
-  // Copy extents from greater operand that are not challenged.
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
-  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
-                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-                           Value extent = b.create<ExtractElementOp>(
-                               loc, greaterOperand, ValueRange{iv});
-                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-                           b.create<scf::YieldOp>(loc);
-                         });
-
-  // Determine remaining broadcasted extents.
-  rewriter.create<ForOp>(
-      loc, rankDiff, greaterRank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-        Value greaterOperandExtent =
-            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
-        Value greaterOperandExtentIsOne =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
-        auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
-            [&](OpBuilder &b, Location loc) {
-              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-              Value smallerOperandExtent = b.create<ExtractElementOp>(
-                  loc, smallerOperand, ValueRange{ivShifted});
-              b.create<scf::YieldOp>(loc, smallerOperandExtent);
-            },
-            [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterOperandExtent);
-            });
-        Value extent = ifOp.getResult(0);
-        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-        b.create<scf::YieldOp>(loc);
-      });
-
-  // Load broadcasted shape as an extent tensor.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
-  return success();
-}
-
-namespace {
-/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
-/// only defined on `tensor<?xindex>` operands. The test for equality first
-/// compares their size and, if equal, checks every extent for equality.
-///
-/// Example:
-///
-/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
-///
-/// becomes
-///
-/// %c0 = constant 0 : index
-/// %0 = dim %arg0, %c0 : tensor<?xindex>
-/// %1 = dim %arg1, %c0 : tensor<?xindex>
-/// %2 = cmpi "eq", %0, %1 : index
-/// %result = scf.if %2 -> (i1) {
-///   %c1 = constant 1 : index
-///   %true = constant true
-///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
-///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
-///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
-///     %7 = cmpi "eq", %5, %6 : index
-///     %8 = and %arg3, %7 : i1
-///     scf.yield %8 : i1
-///   }
-///   scf.yield %4 : i1
-/// } else {
-///   %false = constant false
-///   scf.yield %false : i1
-/// }
-///
-struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
-  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
-  // on shapes.
-  if (op.lhs().getType().isa<ShapeType>() ||
-      op.rhs().getType().isa<ShapeType>()) {
-    return failure();
-  }
-
-  ShapeEqOp::Adaptor transformed(operands);
-  auto loc = op.getLoc();
-  Type indexTy = rewriter.getIndexType();
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
-  Value eqRank =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
-  Type i1Ty = rewriter.getI1Type();
-  rewriter.replaceOpWithNewOp<IfOp>(
-      op, i1Ty, eqRank,
-      [&](OpBuilder &b, Location loc) {
-        Value one = b.create<ConstantIndexOp>(loc, 1);
-        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
-        auto loop = b.create<scf::ForOp>(
-            loc, zero, lhsRank, one, ValueRange{init},
-            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-              Value conj = args[0];
-              Value lhsExtent =
-                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
-              Value rhsExtent =
-                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
-              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                                lhsExtent, rhsExtent);
-              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
-              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
-            });
-        b.create<scf::YieldOp>(loc, loop.getResults());
-      },
-      [&](OpBuilder &b, Location loc) {
-        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
-        b.create<scf::YieldOp>(loc, result);
-      });
-  return success();
-}
-
-namespace {
-/// Converts `shape.reduce` to `scf.for`.
-struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-} // namespace
-
-LogicalResult
-ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
-                                   ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering is only defined on `tensor<?xindex>` operands.
-  if (op.shape().getType().isa<ShapeType>())
-    return failure();
-
-  auto loc = op.getLoc();
-  shape::ReduceOp::Adaptor transformed(operands);
-
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  Type indexTy = rewriter.getIndexType();
-  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
-
-  auto loop = rewriter.create<scf::ForOp>(
-      loc, zero, rank, one, op.initVals(),
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
-
-        SmallVector<Value, 2> mappedValues{iv, extent};
-        mappedValues.append(args.begin(), args.end());
-
-        BlockAndValueMapping mapping;
-        Block *reduceBody = op.getBody();
-        mapping.map(reduceBody->getArguments(), mappedValues);
-        for (auto &nested : reduceBody->without_terminator())
-          b.clone(nested, mapping);
-
-        SmallVector<Value, 2> mappedResults;
-        for (auto result : reduceBody->getTerminator()->getOperands())
-          mappedResults.push_back(mapping.lookup(result));
-        b.create<scf::YieldOp>(loc, mappedResults);
-      });
-
-  rewriter.replaceOp(op, loop.getResults());
-  return success();
-}
-
-namespace {
-/// Converts `shape_of` to for loop for unranked tensors.
-class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
-public:
-  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) const {
-  // For now, this lowering supports only error-free arguments.
-  if (op.getType().isa<ShapeType>())
-    return failure();
-
-  // For ranked tensors `shape_of` lowers to `std` and the pattern can be
-  // found in the corresponding pass.
-  ShapeOfOp::Adaptor transformed(operands);
-  Value arg = transformed.arg();
-  Type argTy = arg.getType();
-  if (argTy.isa<RankedTensorType>())
-    return failure();
-
-  // Allocate stack memory.
-  auto loc = op.getLoc();
-  Value rank = rewriter.create<mlir::RankOp>(loc, arg);
-  Type indexTy = rewriter.getIndexType();
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
-
-  // Copy shape extents to stack-allocated memory.
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  rewriter.create<scf::ForOp>(
-      loc, zero, rank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-        Value dim = rewriter.create<DimOp>(loc, arg, iv);
-        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
-        rewriter.create<scf::YieldOp>(loc);
-      });
-
-  // Load extents to tensor value.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
-  return success();
-}
-
-namespace {
-struct ConvertShapeToSCFPass
-    : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
-  void runOnFunction() override;
-};
-} // namespace
-
-void ConvertShapeToSCFPass::runOnFunction() {
-  MLIRContext &ctx = getContext();
-
-  // Populate conversion patterns.
-  OwningRewritePatternList patterns;
-  populateShapeToSCFConversionPatterns(patterns, &ctx);
-
-  // Setup target legality.
-  ConversionTarget target(getContext());
-  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
-
-  // Apply conversion.
-  if (failed(applyPartialConversion(getFunction(), target, patterns)))
-    signalPassFailure();
-}
-
-void mlir::populateShapeToSCFConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  // clang-format off
-  patterns.insert<
-      BroadcastOpConverter,
-      ShapeEqOpConverter,
-      ReduceOpConverter,
-      ShapeOfOpConverter>(ctx);
-  // clang-format on
-}
-
-std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
-  return std::make_unique<ConvertShapeToSCFPass>();
-}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -12,10 +12,12 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace mlir::shape;
+using namespace mlir::scf;
 
 /// Conversion patterns.
 namespace {
@@ -63,67 +65,94 @@
 } // namespace
 
 namespace {
-class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
-public:
-  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
-public:
-  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
+  using OpConversionPattern<BroadcastOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
-LogicalResult ShapeOfOpConversion::matchAndRewrite(
-    ShapeOfOp op, ArrayRef<Value> operands,
+LogicalResult BroadcastOpConverter::matchAndRewrite(
+    BroadcastOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-
-  // For now, only error-free types are supported by this lowering.
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
   if (op.getType().isa<ShapeType>())
     return failure();
 
-  // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
-  // found in the corresponding pass.
-  ShapeOfOp::Adaptor transformed(operands);
-  Value tensorVal = transformed.arg();
-  Type tensorTy = tensorVal.getType();
-  if (tensorTy.isa<UnrankedTensorType>())
-    return failure();
-
-  // Build values for individual dimensions.
-  SmallVector<Value, 8> dimValues;
-  RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
-  int64_t rank = rankedTensorTy.getRank();
+  assert(!op.lhs().getType().isa<ShapeType>() &&
+         !op.rhs().getType().isa<ShapeType>());
   auto loc = op.getLoc();
-  for (int64_t i = 0; i < rank; i++) {
-    if (rankedTensorTy.isDynamicDim(i)) {
-      Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
-      dimValues.push_back(dimVal);
-    } else {
-      int64_t dim = rankedTensorTy.getDimSize(i);
-      Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
-      dimValues.push_back(dimVal);
-    }
-  }
-
-  // Materialize extent tensor.
-  Value staticExtentTensor =
-      rewriter.create<TensorFromElementsOp>(loc, dimValues);
-  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
-                                            op.getType());
+  BroadcastOp::Adaptor transformed(operands);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsSmaller =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+  Type indexTy = rewriter.getIndexType();
+  Type extentTensorTy = op.getType();
+  auto ifOp = rewriter.create<IfOp>(
+      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+      lhsSmaller,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
+                                               rhsRank, transformed.rhs()});
+      },
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
+                                               lhsRank, transformed.lhs()});
+      });
+  Value smallerRank = ifOp.getResult(0);
+  Value smallerOperand = ifOp.getResult(1);
+  Value greaterRank = ifOp.getResult(2);
+  Value greaterOperand = ifOp.getResult(3);
+
+  // Allocate stack memory for the broadcasted extent tensor.
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
+
+  // Copy extents from greater operand that are not challenged.
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
+                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+                           Value extent = b.create<ExtractElementOp>(
+                               loc, greaterOperand, ValueRange{iv});
+                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+                           b.create<scf::YieldOp>(loc);
+                         });
+
+  // Determine remaining broadcasted extents.
+  rewriter.create<ForOp>(
+      loc, rankDiff, greaterRank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+        Value greaterOperandExtent =
+            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+        Value greaterOperandExtentIsOne =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+        auto ifOp = b.create<IfOp>(
+            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+            [&](OpBuilder &b, Location loc) {
+              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+              Value smallerOperandExtent = b.create<ExtractElementOp>(
+                  loc, smallerOperand, ValueRange{ivShifted});
+              b.create<scf::YieldOp>(loc, smallerOperandExtent);
+            },
+            [&](OpBuilder &b, Location loc) {
+              b.create<scf::YieldOp>(loc, greaterOperandExtent);
+            });
+        Value extent = ifOp.getResult(0);
+        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load broadcasted shape as an extent tensor.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
   return success();
 }
 
@@ -161,26 +190,23 @@
 }
 
 namespace {
-class ToExtentTensorOpConversion
-    : public OpConversionPattern<ToExtentTensorOp> {
+class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
 public:
-  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    ToExtentTensorOpAdaptor adaptor(operands);
-
-    if (!adaptor.input().getType().isa<RankedTensorType>())
-      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
-
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
-                                              op.getType());
-    return success();
-  }
+  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
+LogicalResult ConstSizeOpConversion::matchAndRewrite(
+    ConstSizeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -239,6 +265,236 @@
   return success();
 }
 
+namespace {
+/// Converts `shape.reduce` to `scf.for`.
+struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands.
+  if (op.shape().getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  shape::ReduceOp::Adaptor transformed(operands);
+
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Type indexTy = rewriter.getIndexType();
+  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
+
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, rank, one, op.initVals(),
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
+
+        SmallVector<Value, 2> mappedValues{iv, extent};
+        mappedValues.append(args.begin(), args.end());
+
+        BlockAndValueMapping mapping;
+        Block *reduceBody = op.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
+        for (auto &nested : reduceBody->without_terminator())
+          b.clone(nested, mapping);
+
+        SmallVector<Value, 2> mappedResults;
+        for (auto result : reduceBody->getTerminator()->getOperands())
+          mappedResults.push_back(mapping.lookup(result));
+        b.create<scf::YieldOp>(loc, mappedResults);
+      });
+
+  rewriter.replaceOp(op, loop.getResults());
+  return success();
+}
+
+namespace {
+/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
+/// only defined on `tensor<?xindex>` operands. The test for equality first
+/// compares their size and, if equal, checks every extent for equality.
+///
+/// Example:
+///
+/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+///
+/// becomes
+///
+/// %c0 = constant 0 : index
+/// %0 = dim %arg0, %c0 : tensor<?xindex>
+/// %1 = dim %arg1, %c0 : tensor<?xindex>
+/// %2 = cmpi "eq", %0, %1 : index
+/// %result = scf.if %2 -> (i1) {
+///   %c1 = constant 1 : index
+///   %true = constant true
+///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
+///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
+///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
+///     %7 = cmpi "eq", %5, %6 : index
+///     %8 = and %arg3, %7 : i1
+///     scf.yield %8 : i1
+///   }
+///   scf.yield %4 : i1
+/// } else {
+///   %false = constant false
+///   scf.yield %false : i1
+/// }
+///
+struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
+  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  if (op.lhs().getType().isa<ShapeType>() ||
+      op.rhs().getType().isa<ShapeType>()) {
+    return failure();
+  }
+
+  ShapeEqOp::Adaptor transformed(operands);
+  auto loc = op.getLoc();
+  Type indexTy = rewriter.getIndexType();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
+  Value eqRank =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
+  Type i1Ty = rewriter.getI1Type();
+  rewriter.replaceOpWithNewOp<IfOp>(
+      op, i1Ty, eqRank,
+      [&](OpBuilder &b, Location loc) {
+        Value one = b.create<ConstantIndexOp>(loc, 1);
+        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+        auto loop = b.create<scf::ForOp>(
+            loc, zero, lhsRank, one, ValueRange{init},
+            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+              Value conj = args[0];
+              Value lhsExtent =
+                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
+              Value rhsExtent =
+                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
+              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                lhsExtent, rhsExtent);
+              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+            });
+        b.create<scf::YieldOp>(loc, loop.getResults());
+      },
+      [&](OpBuilder &b, Location loc) {
+        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+        b.create<scf::YieldOp>(loc, result);
+      });
+  return success();
+}
+
+namespace {
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ShapeOfOpConversion::matchAndRewrite(
+    ShapeOfOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  // For now, only error-free types are supported by this lowering.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  // For ranked tensor arguments, lower to `tensor_from_elements`.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value tensor = transformed.arg();
+  Type tensorTy = tensor.getType();
+  if (tensorTy.isa<RankedTensorType>()) {
+
+    // Build values for individual extents.
+    SmallVector<Value, 8> extentValues;
+    RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
+    int64_t rank = rankedTensorTy.getRank();
+    auto loc = op.getLoc();
+    for (int64_t i = 0; i < rank; i++) {
+      if (rankedTensorTy.isDynamicDim(i)) {
+        Value extent = rewriter.create<DimOp>(loc, tensor, i);
+        extentValues.push_back(extent);
+      } else {
+        Value extent =
+            rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
+        extentValues.push_back(extent);
+      }
+    }
+
+    // Materialize extent tensor.
+    Value staticExtentTensor =
+        rewriter.create<TensorFromElementsOp>(loc, extentValues);
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                              op.getType());
+    return success();
+  }
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
+  Type indexTy = rewriter.getIndexType();
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
+
+  // Copy shape extents to stack-allocated memory.
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zero, rank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value dim = rewriter.create<DimOp>(loc, tensor, iv);
+        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
+        rewriter.create<scf::YieldOp>(loc);
+      });
+
+  // Load extents to tensor value.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
+  return success();
+}
+
+namespace {
+class ToExtentTensorOpConversion
+    : public OpConversionPattern<ToExtentTensorOp> {
+public:
+  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ToExtentTensorOpAdaptor adaptor(operands);
+
+    if (!adaptor.input().getType().isa<RankedTensorType>())
+      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
+                                              op.getType());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 /// Conversion pass.
 class ConvertShapeToStandardPass
@@ -252,7 +508,7 @@
   // Setup target legality.
   MLIRContext &ctx = getContext();
   ConversionTarget target(ctx);
-  target.addLegalDialect<StandardOpsDialect>();
+  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
   target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
 
   // Setup conversion patterns.
@@ -271,11 +527,14 @@
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
-      ConstShapeOpConverter,
       BinaryOpConversion<MulOp, MulIOp>,
+      BroadcastOpConverter,
+      ConstShapeOpConverter,
       ConstSizeOpConversion,
       GetExtentOpConverter,
      RankOpConverter,
+      ReduceOpConverter,
+      ShapeEqOpConverter,
       ShapeOfOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
deleted file mode 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ /dev/null
@@ -1,132 +0,0 @@
-// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: @shape_reduce
-// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @shape_reduce(%shape : tensor<?xindex>) -> index {
-  %init = constant 1 : index
-  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
-    ^bb0(%index : index, %extent : index, %acc: index):
-      %new_acc = muli %acc, %extent : index
-      shape.yield %new_acc : index
-  }
-  return %num_elements : index
-}
-// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
-// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
-// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
-// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
-// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
-// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
-// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RESULT]] : index
-
-// -----
-
-// Don't lower `shape_of` for result type of `shape.shape`.
-// CHECK-LABEL: @shape_of
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
-func @shape_of(%arg : tensor<*xf32>) {
-  // CHECK: shape.shape
-  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
-  return
-}
-
-// -----
-
-// Lower `shape_of` for unranked tensors.
-// CHECK-LABEL: @shape_of_unranked
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
-func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK:   store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
-  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @shape_eq
-// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
-func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
-  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
-  // CHECK:   %[[C1:.*]] = constant 1 : index
-  // CHECK:   %[[INIT:.*]] = constant true
-  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
-  // CHECK:     %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
-  // CHECK:     %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
-  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
-  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
-  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
-  // CHECK:   }
-  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
-  // CHECK: } else {
-  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
-  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
-  // CHECK: }
-  // CHECK: return %[[SHAPE_EQ]] : i1
-  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
-  return %result : i1
-}
-
-// -----
-
-// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
-// CHECK-LABEL: @broadcast
-func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
-  // CHECK: shape.broadcast
-  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
-  return %c : !shape.shape
-}
-
-// -----
-
-// CHECK-LABEL: @broadcast
-// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
-  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: } else {
-  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: }
-  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
-  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
-  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
-  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
-  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
-  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
-  // CHECK:   } else {
-  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
-  // CHECK:   }
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
-  %0 = shape.broadcast %a, %b
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return
-}
-
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -26,46 +26,6 @@
 
 // -----
 
-// Don't lower `shape_of` with `shape.shape` type.
-// CHECK-LABEL: @shape_of
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
-func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
-  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
-  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
-  return
-}
-
-// -----
-
-// Lower `shape_of` for statically shaped tensor.
-// CHECK-LABEL: @shape_of_stat
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
-func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
-  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
-  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// Lower `shape_of` for dynamically shaped tensor.
-// CHECK-LABEL: @shape_of_dyn
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
-func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
-  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
-  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
-  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
-  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
-  return
-}
-
-// -----
-
 // Convert `rank` to `dim` of the first dimension.
 // CHECK-LABEL: @rank
 // CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
@@ -190,3 +150,174 @@
 // CHECK: return %[[RES]]
   return %casted : tensor<3xindex>
 }
+
+// CHECK-LABEL: @shape_reduce
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @shape_reduce(%shape : tensor<?xindex>) -> index {
+  %init = constant 1 : index
+  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+    ^bb0(%index : index, %extent : index, %acc: index):
+      %new_acc = muli %acc, %extent : index
+      shape.yield %new_acc : index
+  }
+  return %num_elements : index
+}
+// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
+// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
+// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
+// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT]] : index
+
+// -----
+
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+  // CHECK: shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+  return
+}
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK:   store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
+  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+// Don't lower `shape_of` with `shape.shape` type.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
+  return
+}
+
+// -----
+
+// Lower `shape_of` for statically shaped tensor.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+// Lower `shape_of` for dynamically shaped tensor.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+  // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+  %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @shape_eq
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: return %[[SHAPE_EQ]] : i1
+  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+// -----
+
+// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
+// CHECK-LABEL: @broadcast
+func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
+  // CHECK: shape.broadcast
+  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
+  return %c : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast
+// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
+  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: }
+  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
+  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  %0 = shape.broadcast %a, %b
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return
+}